Search in sources :

Example 11 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class GATKProtectedMathUtils method rowStdDevs.

/**
 * Computes one standard deviation per row of a matrix.
 * <p>
 * Each entry is the square root of the value returned by a default-constructed
 * {@link Variance} evaluated on the corresponding row.
 * </p>
 *
 * @param matrix the input matrix.
 * @return never {@code null}, an array with as many positions as rows in {@code matrix},
 *         in row order.
 * @throws IllegalArgumentException if {@code matrix} is {@code null}.
 */
public static double[] rowStdDevs(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance rowVariance = new Variance();
    final int rowCount = matrix.getRowDimension();
    final double[] stdDevs = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        final double variance = rowVariance.evaluate(matrix.getRow(row));
        stdDevs[row] = Math.sqrt(variance);
    }
    return stdDevs;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Pair(org.apache.commons.math3.util.Pair) MathArrays(org.apache.commons.math3.util.MathArrays) Collection(java.util.Collection) FastMath(org.apache.commons.math3.util.FastMath) EnumeratedDistribution(org.apache.commons.math3.distribution.EnumeratedDistribution) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) FourierLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperator) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) Serializable(java.io.Serializable) List(java.util.List) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Median(org.apache.commons.math3.stat.descriptive.rank.Median) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Nonnull(javax.annotation.Nonnull) Collections(java.util.Collections) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance)

Example 12 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class GATKProtectedMathUtils method rowVariances.

/**
 * Computes one variance per row of a matrix.
 * <p>
 * Each entry is the value returned by a default-constructed {@link Variance}
 * evaluated on the corresponding row.
 * </p>
 *
 * @param matrix the input matrix; checked with {@code Utils.nonNull} before use.
 * @return an array with as many positions as rows in {@code matrix}, in row order.
 */
public static double[] rowVariances(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance rowVariance = new Variance();
    final int rowCount = matrix.getRowDimension();
    final double[] variances = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        variances[row] = rowVariance.evaluate(matrix.getRow(row));
    }
    return variances;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Pair(org.apache.commons.math3.util.Pair) MathArrays(org.apache.commons.math3.util.MathArrays) Collection(java.util.Collection) FastMath(org.apache.commons.math3.util.FastMath) EnumeratedDistribution(org.apache.commons.math3.distribution.EnumeratedDistribution) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) FourierLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperator) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) Serializable(java.io.Serializable) List(java.util.List) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Median(org.apache.commons.math3.stat.descriptive.rank.Median) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Nonnull(javax.annotation.Nonnull) Collections(java.util.Collections) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance)

Example 13 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateSampleUnexplainedVariance.

/**
 * E-step update of the sample-specific unexplained variance.
 * <p>
 * One root-finding job per sample is submitted to a {@link SynchronizedUnivariateSolver}; the
 * objective function is evaluated for several samples at once by reducing the per-worker results.
 * The solved values are admixed with the current values using
 * {@code config.getMeanFieldAdmixingRatio()}, stored locally, and pushed to the workers.
 * </p>
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 * number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if the solver is interrupted; the thread's interrupt flag is restored and
 * the original {@link InterruptedException} is attached as the cause
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");
    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        /* fix an ordering of the queried samples so that inputs and outputs line up by position */
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
        final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };
    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    IntStream.range(0, numSamples).forEach(si -> {
        /* start each search at the midpoint of the interval [0, upper limit] */
        final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
        syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations());
    });
    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
        newSampleSpecificVarianceMap.entrySet().forEach(entry -> {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            /* any status other than SUCCESS leaves the value for this sample at 0 */
            double val = 0;
            switch(summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of" + " function evaluations");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        });
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that code further up the stack can still observe it */
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }
    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
    /* push to workers */
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
    /* average number of function evaluations per sample, truncated to an integer */
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications)

Example 14 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method initializeWorkersWithPCA.

/**
 * Initialize model parameters by performing PCA.
 * <p>
 * The steps, in order: zero out {@code m_t}, {@code Psi_t} (and {@code W_tl} when bias covariates
 * are enabled) on the workers; update the read depth posterior expectations; build the target
 * covariance matrix by reducing over workers and eigen-decompose it; derive an isotropic
 * unexplained variance and per-latent scale factors from the eigenvalues; assemble the bias
 * covariates column by column; and finally push the resulting {@link CoverageModelParameters}
 * to the workers and update the bias latent posterior expectations.
 * </p>
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
    logger.info("Initializing model parameters using PCA...");
    /* initially, set m_t, Psi_t and W_tl to zero to get an estimate of the read depth */
    final int numLatents = config.getNumLatents();
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t, Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })));
    /* W_tl is only zeroed when bias covariates are enabled */
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), numLatents })));
    }
    /* update read depth without taking into account correction from bias covariates */
    updateReadDepthPosteriorExpectations(1.0, true);
    /* fetch sample covariance matrix */
    final int minPCAInitializationReadCount = config.getMinPCAInitializationReadCount();
    mapWorkers(cb -> cb.cloneWithPCAInitializationData(minPCAInitializationReadCount, Integer.MAX_VALUE));
    cacheWorkers("PCA initialization");
    /* per-worker covariance contributions are summed element-wise */
    final INDArray targetCovarianceMatrix = mapWorkersAndReduce(CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);
    /* perform eigen-decomposition on the target covariance matrix */
    /* the left element of the pair is used below as the eigenvalues, the right as the eigenvectors */
    final ImmutablePair<INDArray, INDArray> targetCovarianceEigensystem = CoverageModelEMWorkspaceMathUtils.eig(targetCovarianceMatrix, false, logger);
    /* the eigenvalues of sample covariance matrix can be immediately inferred by scaling */
    final INDArray sampleCovarianceEigenvalues = targetCovarianceEigensystem.getLeft().div(numSamples);
    /* estimate the isotropic unexplained variance -- see Bishop 12.46 */
    /* NOTE(review): the slice below ends at numSamples while the divisor is numTargets - numLatents -- confirm this is intended */
    final int residualDim = numTargets - numLatents;
    final double isotropicVariance = sampleCovarianceEigenvalues.get(NDArrayIndex.interval(numLatents, numSamples)).sumNumber().doubleValue() / residualDim;
    logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));
    /* estimate bias factors -- see Bishop 12.45 */
    /* scale factor per latent: sqrt(eigenvalue - isotropic variance) over the first numLatents eigenvalues */
    final INDArray scaleFactors = Transforms.sqrt(sampleCovarianceEigenvalues.get(NDArrayIndex.interval(0, numLatents)).sub(isotropicVariance), false);
    final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, numLatents });
    for (int li = 0; li < numLatents; li++) {
        final INDArray v = targetCovarianceEigensystem.getRight().getColumn(li);
        /* calculate [Delta_PCA_st]^T v */
        /* note: we do not need to broadcast v since it is small and lambda capture is just fine */
        final INDArray unnormedBiasCovariate = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(mapWorkersAndCollect(cb -> ImmutablePair.of(cb.getTargetSpaceBlock(), cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st).transpose().mmul(v))), 0);
        /* NOTE(review): normalization uses norm1Number (L1 norm) -- confirm that this, not the L2 norm, is intended */
        final double norm = unnormedBiasCovariate.norm1Number().doubleValue();
        /* divi/muli operate in place on unnormedBiasCovariate */
        final INDArray normedBiasCovariate = unnormedBiasCovariate.divi(norm).muli(scaleFactors.getDouble(li));
        biasCovariatesPCA.getColumn(li).assign(normedBiasCovariate);
    }
    if (ardEnabled) {
        /* a better estimate of ARD coefficients */
        biasCovariatesARDCoefficients.assign(Nd4j.zeros(new int[] { 1, numLatents }).addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance));
    }
    /* the second and third arguments are a zero row vector and a row vector filled with the isotropic variance */
    final CoverageModelParameters modelParamsFromPCA = new CoverageModelParameters(processedTargetList, Nd4j.zeros(new int[] { 1, numTargets }), Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance), biasCovariatesPCA, biasCovariatesARDCoefficients);
    /* clear PCA initialization data from workers */
    mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);
    /* push model parameters to workers */
    initializeWorkersWithGivenModel(modelParamsFromPCA);
    /* update bias latent posterior expectations without admixing */
    updateBiasLatentPosteriorExpectations(1.0);
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Example 15 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateTargetUnexplainedVarianceIsotropic.

/**
 * M-step update of unexplained variance in the isotropic mode.
 * <p>
 * A single scalar variance is found by root-finding on an objective function summed over all
 * workers, searched on the interval from 0 to {@code config.getTargetSpecificVarianceUpperLimit()}.
 * If the solver throws {@link NoBracketingException} or {@link TooManyEvaluationsException}, a
 * warning is logged and the previous value (the mean of {@code Psi_t} fetched from the workers)
 * is kept. The resulting scalar is then written to every element of {@code Psi_t} on each worker.
 * </p>
 *
 * @return a {@link SubroutineSignal} object containing "error_norm" (absolute change of the scalar
 * variance) and "iterations" (the number of evaluations reported by the solver) fields
 */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");
    /* the current isotropic value is recovered as the mean of Psi_t */
    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();
    /* objective and merit functions are sums of the per-worker contributions */
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);
    final RobustBrentSolver solver = new RobustBrentSolver(config.getTargetSpecificVarianceRelativeTolerance(), config.getTargetSpecificVarianceAbsoluteTolerance(), CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, config.getTargetSpecificVarianceSolverNumBisections(), config.getTargetSpecificVarianceSolverRefinementDepth());
    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(), objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        /* no root could be bracketed on the search interval: keep the previous value */
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        /* evaluation budget exhausted: keep the previous value */
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update" + " of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }
    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int numEvaluations = solver.getEvaluations();
    /* lambdas can only capture effectively final locals, hence the copy */
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, numEvaluations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction)

Aggregations

Collectors (java.util.stream.Collectors)24 IntStream (java.util.stream.IntStream)24 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)22 Nonnull (javax.annotation.Nonnull)20 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)18 Variance (org.apache.commons.math3.stat.descriptive.moment.Variance)18 List (java.util.List)16 FastMath (org.apache.commons.math3.util.FastMath)16 Utils (org.broadinstitute.hellbender.utils.Utils)16 INDArray (org.nd4j.linalg.api.ndarray.INDArray)16 Function (java.util.function.Function)15 Arrays (java.util.Arrays)14 Nullable (javax.annotation.Nullable)14 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)14 RealMatrix (org.apache.commons.math3.linear.RealMatrix)14 Logger (org.apache.logging.log4j.Logger)14 GATKException (org.broadinstitute.hellbender.exceptions.GATKException)14 UserException (org.broadinstitute.hellbender.exceptions.UserException)14 Nd4j (org.nd4j.linalg.factory.Nd4j)14 NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex)14