Search in sources :

Example 1 with SynchronizedUnivariateSolver

use of org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver in project gatk-protected by broadinstitute.

the class CoverageModelEMWorkspace method updateSampleUnexplainedVariance.

/**
 * E-step update of the sample-specific unexplained variance.
 *
 * <p>For every sample, a root of the sample-specific variance objective function is searched for
 * with a synchronized multi-sample solver. The new values are admixed with the current values
 * using the configured mean-field admixing ratio, stored locally, and pushed to the workers.</p>
 *
 * <p>If the solver does not report {@code SUCCESS} for a sample, a warning is logged and the
 * pre-admixing value for that sample falls back to 0.</p>
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 * number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if the solver is interrupted; the interrupt flag of the current thread is
 * restored and the original {@link InterruptedException} is attached as the cause
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");
    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        /* column vector of query points, ordered like sampleIndices */
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
        /* each worker contributes a partial evaluation; partial results are summed */
        final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };
    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    /* the search settings are the same for every sample, so read them once */
    final double upperLimit = config.getSampleSpecificVarianceUpperLimit();
    final double x0 = 0.5 * upperLimit;
    IntStream.range(0, numSamples).forEach(si ->
            syncSolver.add(si, 0, upperLimit, x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations()));
    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
        for (final Map.Entry<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> entry : newSampleSpecificVarianceMap.entrySet()) {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            /* falls back to 0 unless the solver succeeded */
            double val = 0;
            switch(summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of function evaluations (sample index: " + sampleIndex + ")");
                    break;
                default:
                    /* do not let any other failure mode pass silently */
                    logger.warn("Could not locate the root of gamma -- solver status: " + summary.status + " (sample index: " + sampleIndex + ")");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        }
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that code further up the stack can still observe it */
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }
    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
    /* push to workers */
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
    /* average number of function evaluations per sample, truncated to an integer */
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications)

Example 2 with SynchronizedUnivariateSolver

use of org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateSampleUnexplainedVariance.

/**
 * E-step update of the sample-specific unexplained variance.
 *
 * <p>For every sample, a root of the sample-specific variance objective function is searched for
 * with a synchronized multi-sample solver. The new values are admixed with the current values
 * using the configured mean-field admixing ratio, stored locally, and pushed to the workers.</p>
 *
 * <p>If the solver does not report {@code SUCCESS} for a sample, a warning is logged and the
 * pre-admixing value for that sample falls back to 0.</p>
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 * number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if the solver is interrupted; the interrupt flag of the current thread is
 * restored and the original {@link InterruptedException} is attached as the cause
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");
    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        /* column vector of query points, ordered like sampleIndices */
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
        /* each worker contributes a partial evaluation; partial results are summed */
        final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };
    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    /* the search settings are the same for every sample, so read them once */
    final double upperLimit = config.getSampleSpecificVarianceUpperLimit();
    final double x0 = 0.5 * upperLimit;
    IntStream.range(0, numSamples).forEach(si ->
            syncSolver.add(si, 0, upperLimit, x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations()));
    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
        for (final Map.Entry<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> entry : newSampleSpecificVarianceMap.entrySet()) {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            /* falls back to 0 unless the solver succeeded */
            double val = 0;
            switch(summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of function evaluations (sample index: " + sampleIndex + ")");
                    break;
                default:
                    /* do not let any other failure mode pass silently */
                    logger.warn("Could not locate the root of gamma -- solver status: " + summary.status + " (sample index: " + sampleIndex + ")");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        }
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that code further up the stack can still observe it */
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }
    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
    /* push to workers */
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
    /* average number of function evaluations per sample, truncated to an integer */
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications)

Aggregations

VisibleForTesting (com.google.common.annotations.VisibleForTesting)2 Sets (com.google.common.collect.Sets)2 File (java.io.File)2 IOException (java.io.IOException)2 java.util (java.util)2 BiFunction (java.util.function.BiFunction)2 Predicate (java.util.function.Predicate)2 Collectors (java.util.stream.Collectors)2 IntStream (java.util.stream.IntStream)2 Stream (java.util.stream.Stream)2 Nonnull (javax.annotation.Nonnull)2 Nullable (javax.annotation.Nullable)2 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)2 ImmutableTriple (org.apache.commons.lang3.tuple.ImmutableTriple)2 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)2 AbstractUnivariateSolver (org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver)2 NoBracketingException (org.apache.commons.math3.exception.NoBracketingException)2 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)2 FastMath (org.apache.commons.math3.util.FastMath)2 LogManager (org.apache.logging.log4j.LogManager)2