Search in sources:

Example 6 with RobustBrentSolver

Use of org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver in the GATK project by the Broad Institute.

The method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved of the class CoverageModelEMComputeBlock.

/**
     * Performs the M-step for target-specific unexplained variance and clones the compute block
     * with the updated value.
     *
     * <p>For every target index in {@code [0, numTargets)} a {@link RobustBrentSolver} searches for a root
     * of the target-specific objective function on the interval {@code [0, psiUpperLimit]}. If the solver
     * throws {@link NoBracketingException} or {@link TooManyEvaluationsException}, that target keeps its
     * current value from the {@code Psi_t} cache. Targets are processed in parallel on a {@link ForkJoinPool}
     * with {@code numThreads} workers that is created for this call and shut down before the method returns.
     *
     * @param maxIters maximum number of iterations (passed as the first argument of {@code solver.solve})
     * @param psiUpperLimit upper limit for the unexplained variance
     * @param absTol absolute error tolerance (used in root finding)
     * @param relTol relative error tolerance (used in root finding)
     * @param numBisections number of bisections (used in root finding)
     * @param refinementDepth depth of search (used in root finding)
     * @param numThreads number of worker threads used to solve the per-target problems in parallel
     *
     * @return a new instance of {@link CoverageModelEMComputeBlock} holding the updated {@code Psi_t}, together
     *         with a signal carrying the infinity norm of the change in {@code Psi_t}
     *         ({@code RESIDUAL_ERROR_NORM}) and the largest per-target solver evaluation count
     *         ({@code ITERATIONS}; 0 when there are no targets)
     * @throws RuntimeException if the parallel computation is interrupted or fails; the original exception
     *         is attached as the cause
     */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");
    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);
    /* a dedicated pool bounds the parallelism to numThreads; it must be shut down to release its threads */
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    /* one (new psi, number of solver evaluations) pair per target, in target order */
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* solver failure is tolerated: keep the current value for this target */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt status so that callers can still observe the interruption */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* the pool is private to this call; release its worker threads whether or not the work succeeded */
        forkJoinPool.shutdown();
    }
    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* largest evaluation count over all targets; 0 (instead of an exception) when there are no targets */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().orElse(0);
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) Map(java.util.Map) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) Nd4j(org.nd4j.linalg.factory.Nd4j) FastMath(org.apache.commons.math3.util.FastMath) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Serializable(java.io.Serializable) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) org.broadinstitute.hellbender.tools.coveragemodel.cachemanager(org.broadinstitute.hellbender.tools.coveragemodel.cachemanager) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ForkJoinPool(java.util.concurrent.ForkJoinPool) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) Collections(java.util.Collections) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ExecutionException(java.util.concurrent.ExecutionException) ForkJoinPool(java.util.concurrent.ForkJoinPool)

Aggregations

VisibleForTesting (com.google.common.annotations.VisibleForTesting)6 Collectors (java.util.stream.Collectors)6 IntStream (java.util.stream.IntStream)6 Nonnull (javax.annotation.Nonnull)6 Nullable (javax.annotation.Nullable)6 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)6 ImmutableTriple (org.apache.commons.lang3.tuple.ImmutableTriple)6 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)6 NoBracketingException (org.apache.commons.math3.exception.NoBracketingException)6 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)6 FastMath (org.apache.commons.math3.util.FastMath)6 RobustBrentSolver (org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver)6 Utils (org.broadinstitute.hellbender.utils.Utils)6 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)6 INDArray (org.nd4j.linalg.api.ndarray.INDArray)6 Nd4j (org.nd4j.linalg.factory.Nd4j)6 NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex)6 Transforms (org.nd4j.linalg.ops.transforms.Transforms)6 Sets (com.google.common.collect.Sets)4 File (java.io.File)4