Search in sources:

Example 1 with NoBracketingException

Use of org.apache.commons.math3.exception.NoBracketingException in project gatk-protected by broadinstitute.

In class CoverageModelEMWorkspace, method updateTargetUnexplainedVarianceIsotropic.

/**
     * M-step update of unexplained variance in the isotropic mode.
     *
     * <p>The compute blocks first refresh their {@code M_STEP_PSI} caches. A single scalar variance is then
     * found with a {@link RobustBrentSolver} by root-finding on the objective function summed over all compute
     * blocks, on the interval [0, configured upper limit]. If no root can be bracketed, or the evaluation budget
     * is exhausted, a warning is logged and the previous value (the mean of the current Psi_t) is kept. Finally,
     * every compute block's Psi_t is overwritten with the resulting scalar.</p>
     *
     * @return a {@link SubroutineSignal} object containing "error_norm" (absolute change of the isotropic
     *         variance) and "iterations" (number of objective-function evaluations used by the solver) fields
     */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");
    /* the current isotropic value is summarized by the mean of the distributed Psi_t */
    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();
    /* objective and merit functions are sums of the per-compute-block contributions */
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);
    final RobustBrentSolver solver = new RobustBrentSolver(config.getTargetSpecificVarianceRelativeTolerance(), config.getTargetSpecificVarianceAbsoluteTolerance(), CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, config.getTargetSpecificVarianceSolverNumBisections(), config.getTargetSpecificVarianceSolverRefinementDepth());
    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(), objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        /* best-effort: keep the previous value when no sign change is found on the search interval */
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        /* best-effort: keep the previous value when the evaluation budget is exhausted */
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update" + " of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }
    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int maxIterations = solver.getEvaluations();
    /* effectively-final copy so it can be captured by the lambda below */
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction)

Example 2 with NoBracketingException

Use of org.apache.commons.math3.exception.NoBracketingException in project gatk-protected by broadinstitute.

In class CoverageModelEMComputeBlock, method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.

/**
     * Performs the M-step for target-specific unexplained variance and clones the compute block
     * with the updated value.
     *
     * <p>For every target, the optimality equation for its unexplained variance is solved independently on
     * [0, {@code psiUpperLimit}] with a {@link RobustBrentSolver}. Targets are processed by a parallel stream
     * submitted to a dedicated {@link ForkJoinPool} of {@code numThreads} threads; the pool is shut down before
     * this method returns. If, for a given target, the root cannot be bracketed or the evaluation budget is
     * exhausted, the target's current value in the Psi_t cache is kept (best-effort update).</p>
     *
     * <p>All arguments are validated with {@code Utils.validateArg} before any work is done.</p>
     *
     * @param maxIters maximum number of solver evaluations per target (must be positive)
     * @param psiUpperLimit upper limit for the unexplained variance (must be non-negative)
     * @param absTol absolute error tolerance (used in root finding)
     * @param relTol relative error tolerance (used in root finding)
     * @param numBisections number of bisections (used in root finding)
     * @param refinementDepth depth of search (used in root finding)
     * @param numThreads number of threads used to process targets concurrently (must be positive)
     *
     * @return a new instance of {@link CoverageModelEMComputeBlock} with the updated Psi_t; its signal holds the
     *         infinity norm of the change in Psi_t and the largest number of solver evaluations over all targets
     * @throws RuntimeException if the concurrent update is interrupted or a worker task fails (the original
     *         exception is attached as the cause)
     */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");
    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        /* the parallel stream is started from inside the dedicated pool so that (by ForkJoinPool behavior)
         * its work runs on this pool's workers rather than on the common pool */
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (final NoBracketingException | TooManyEvaluationsException e) {
                    /* best-effort: keep the current value of this target when root finding fails */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that callers up the stack can still observe the interruption */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* the pool is private to this call; release its worker threads instead of leaking them */
        forkJoinPool.shutdown();
    }
    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* largest number of solver evaluations used by any single target */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().orElse(0);
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) Map(java.util.Map) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) Nd4j(org.nd4j.linalg.factory.Nd4j) FastMath(org.apache.commons.math3.util.FastMath) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Serializable(java.io.Serializable) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) org.broadinstitute.hellbender.tools.coveragemodel.cachemanager(org.broadinstitute.hellbender.tools.coveragemodel.cachemanager) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ForkJoinPool(java.util.concurrent.ForkJoinPool) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) Collections(java.util.Collections) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ExecutionException(java.util.concurrent.ExecutionException) ForkJoinPool(java.util.concurrent.ForkJoinPool)

Example 3 with NoBracketingException

Use of org.apache.commons.math3.exception.NoBracketingException in project gatk by broadinstitute.

In class CoverageModelEMWorkspace, method updateTargetUnexplainedVarianceIsotropic.

/**
     * M-step update of unexplained variance in the isotropic mode.
     *
     * <p>The compute blocks first refresh their {@code M_STEP_PSI} caches. A single scalar variance is then
     * found with a {@link RobustBrentSolver} by root-finding on the objective function summed over all compute
     * blocks, on the interval [0, configured upper limit]. If no root can be bracketed, or the evaluation budget
     * is exhausted, a warning is logged and the previous value (the mean of the current Psi_t) is kept. Finally,
     * every compute block's Psi_t is overwritten with the resulting scalar.</p>
     *
     * @return a {@link SubroutineSignal} object containing "error_norm" (absolute change of the isotropic
     *         variance) and "iterations" (number of objective-function evaluations used by the solver) fields
     */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");
    /* the current isotropic value is summarized by the mean of the distributed Psi_t */
    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();
    /* objective and merit functions are sums of the per-compute-block contributions */
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);
    final RobustBrentSolver solver = new RobustBrentSolver(config.getTargetSpecificVarianceRelativeTolerance(), config.getTargetSpecificVarianceAbsoluteTolerance(), CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, config.getTargetSpecificVarianceSolverNumBisections(), config.getTargetSpecificVarianceSolverRefinementDepth());
    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(), objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        /* best-effort: keep the previous value when no sign change is found on the search interval */
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        /* best-effort: keep the previous value when the evaluation budget is exhausted */
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update" + " of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }
    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int maxIterations = solver.getEvaluations();
    /* effectively-final copy so it can be captured by the lambda below */
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction)

Example 4 with NoBracketingException

Use of org.apache.commons.math3.exception.NoBracketingException in project gatk-protected by broadinstitute.

In class RobustBrentSolver, method doSolve.

/**
 * Solves for a root of the objective function on [{@link #getMin()}, {@link #getMax()}].
 *
 * <p>The objective function is evaluated on a hybrid search grid built from {@code numBisections} and
 * {@code depth}. Each bracket returned by {@code detectBrackets} is then refined with a standard
 * {@link BrentSolver}. If exactly one root is found, or no merit function is set, the first root is
 * returned. Otherwise the root with the largest merit value is returned; the earliest root wins on ties.</p>
 *
 * @return the selected root
 * @throws NoBracketingException if no bracketing sub-interval is found on the search grid
 * @throws TooManyEvaluationsException if the evaluation budget is exhausted
 */
@Override
protected double doSolve() throws TooManyEvaluationsException, NoBracketingException {
    final double min = getMin();
    final double max = getMax();
    /* evaluate the objective function on the search grid */
    final double[] xSearchGrid = createHybridSearchGrid(min, max, numBisections, depth);
    final double[] fSearchGrid = Arrays.stream(xSearchGrid).map(this::computeObjectiveValue).toArray();
    /* find bracketing intervals on the search grid */
    final List<Bracket> bracketsList = detectBrackets(xSearchGrid, fSearchGrid);
    if (bracketsList.isEmpty()) {
        throw new NoBracketingException(min, max, fSearchGrid[0], fSearchGrid[fSearchGrid.length - 1]);
    }
    /* refine each bracket with a plain Brent solver, starting from the bracket midpoint */
    final BrentSolver solver = new BrentSolver(getRelativeAccuracy(), getAbsoluteAccuracy(), getFunctionValueAccuracy());
    final List<Double> roots = bracketsList.stream().map(b -> solver.solve(getMaxEvaluations(), this::computeObjectiveValue, b.min, b.max, 0.5 * (b.min + b.max))).collect(Collectors.toList());
    if (roots.size() == 1 || meritFunc == null) {
        return roots.get(0);
    }
    /* pick the root with the largest merit. Double.compare is used instead of (int) (a - b), which truncates
     * differences smaller than 1 to zero and so treats distinct merits as ties. */
    final double[] merits = roots.stream().mapToDouble(meritFunc::value).toArray();
    int bestRootIndex = 0;
    for (int i = 1; i < merits.length; i++) {
        if (Double.compare(merits[i], merits[bestRootIndex]) > 0) {
            bestRootIndex = i;
        }
    }
    return roots.get(bestRootIndex);
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) FastMath(org.apache.commons.math3.util.FastMath) Collectors(java.util.stream.Collectors) BrentSolver(org.apache.commons.math3.analysis.solvers.BrentSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) ArrayList(java.util.ArrayList) List(java.util.List) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Nullable(javax.annotation.Nullable) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) BrentSolver(org.apache.commons.math3.analysis.solvers.BrentSolver)

Example 5 with NoBracketingException

Use of org.apache.commons.math3.exception.NoBracketingException in project gatk-protected by broadinstitute.

In class RobustBrentSolverUnitTest, method simpleTest.

/**
     * Checks the robust solver against the plain Brent solver on a quartic with four real roots,
     * x = 0, 1, 2 and 3.
     *
     * <p>The objective 30 x (x - 1) (x - 2) (x - 3) is positive at both ends of the search interval [-1, 4].
     * The plain {@link BrentSolver} therefore sees no sign change and must throw a {@link NoBracketingException}.
     * The robust solver subdivides the interval, finds the bracketing sub-intervals, and ranks the candidate
     * roots with the merit function. The merit function is the anti-derivative of the objective, and it is
     * largest at x = 0, so the robust solver must return x = 0.</p>
     */
@Test
public void simpleTest() {
    final UnivariateFunction objective = x -> 30 * x * (x - 1) * (x - 2) * (x - 3);
    /* anti-derivative of the objective above */
    final UnivariateFunction merit = x -> 6 * FastMath.pow(x, 5) - 45 * FastMath.pow(x, 4) + 110 * FastMath.pow(x, 3) - 90 * FastMath.pow(x, 2);
    /* NOTE(review): DEF_REL_ACC is also passed as the absolute accuracy to both solvers -- confirm intentional */
    final RobustBrentSolver robustSolver = new RobustBrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC, merit, 4, 1);
    final BrentSolver plainSolver = new BrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC);

    final double robustRoot = robustSolver.solve(100, objective, -1, 4);
    Assert.assertEquals(robustRoot, 0, DEF_ABS_ACC);

    boolean plainSolverThrew;
    try {
        /* expected to throw: the end points of [-1, 4] do not bracket a root */
        plainSolver.solve(100, objective, -1, 4);
        plainSolverThrew = false;
    } catch (final NoBracketingException expected) {
        plainSolverThrew = true;
    }
    Assert.assertTrue(plainSolverThrew);
}
Also used : List(java.util.List) Assert(org.testng.Assert) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) FastMath(org.apache.commons.math3.util.FastMath) Test(org.testng.annotations.Test) BrentSolver(org.apache.commons.math3.analysis.solvers.BrentSolver) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) BrentSolver(org.apache.commons.math3.analysis.solvers.BrentSolver) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test)

Aggregations

UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)8 NoBracketingException (org.apache.commons.math3.exception.NoBracketingException)8 FastMath (org.apache.commons.math3.util.FastMath)8 VisibleForTesting (com.google.common.annotations.VisibleForTesting)6 List (java.util.List)6 Collectors (java.util.stream.Collectors)6 IntStream (java.util.stream.IntStream)6 Nullable (javax.annotation.Nullable)6 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)6 Utils (org.broadinstitute.hellbender.utils.Utils)6 Arrays (java.util.Arrays)4 Nonnull (javax.annotation.Nonnull)4 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)4 ImmutableTriple (org.apache.commons.lang3.tuple.ImmutableTriple)4 AbstractUnivariateSolver (org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver)4 BrentSolver (org.apache.commons.math3.analysis.solvers.BrentSolver)4 RobustBrentSolver (org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver)4 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 Nd4j (org.nd4j.linalg.factory.Nd4j)4