
Example 6 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk-protected by broadinstitute.

In class CoverageModelEMComputeBlock, method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.

/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads number of threads used to solve for the targets in parallel
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");
    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        /* solve the M-step equation of each target independently on the numThreads workers of forkJoinPool */
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* keep the current value for this target if the solver fails */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt status before propagating */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Interrupted during concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* the pool is created on every call, so release its worker threads */
        forkJoinPool.shutdown();
    }
    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* the largest number of function evaluations used by any per-target solver */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().orElse(0);
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) Map(java.util.Map) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) Nd4j(org.nd4j.linalg.factory.Nd4j) FastMath(org.apache.commons.math3.util.FastMath) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Serializable(java.io.Serializable) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) org.broadinstitute.hellbender.tools.coveragemodel.cachemanager(org.broadinstitute.hellbender.tools.coveragemodel.cachemanager) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ForkJoinPool(java.util.concurrent.ForkJoinPool) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) Collections(java.util.Collections) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException)
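The target-resolved update above runs one independent root search per target on a dedicated ForkJoinPool and keeps the previous value whenever the solver throws. Below is a minimal, pure-JDK sketch of that pattern, not the GATK implementation: the names `ParallelRootUpdate`, `Root`, `Result`, `bisect` and `solveAll` are hypothetical, plain bisection stands in for `RobustBrentSolver`, and the per-element objectives are supplied as an `IntFunction<DoubleUnaryOperator>` instead of being built from cached `INDArray`s. Like the original, it assumes that a parallel stream started from a task in the pool runs on that pool's workers, which is how `ForkJoinPool` behaves in practice rather than a documented guarantee.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntFunction;
import java.util.stream.IntStream;

/** Hypothetical sketch: one independent 1-D root search per element, run on a dedicated pool. */
final class ParallelRootUpdate {

    /** Outcome of one root search: the root (if found) and the number of function evaluations used. */
    static final class Root {
        final double x;
        final int evaluations;
        final boolean found;
        Root(double x, int evaluations, boolean found) { this.x = x; this.evaluations = evaluations; this.found = found; }
    }

    /** Updated values, the largest evaluation count, and the infinity norm of the change. */
    static final class Result {
        final double[] values;
        final int maxEvaluations;
        final double errNormInfinity;
        Result(double[] values, int maxEvaluations, double errNormInfinity) {
            this.values = values; this.maxEvaluations = maxEvaluations; this.errNormInfinity = errNormInfinity;
        }
    }

    /** Bisection standing in for a Brent-type solver; fails on no bracketing or when the budget runs out. */
    static Root bisect(DoubleUnaryOperator f, double lo, double hi, double absTol, int maxEvals) {
        double fLo = f.applyAsDouble(lo);
        final double fHi = f.applyAsDouble(hi);
        int evals = 2;
        if (fLo == 0 || fHi == 0) {
            return new Root(fLo == 0 ? lo : hi, evals, true);
        }
        if (Math.signum(fLo) == Math.signum(fHi)) {
            return new Root(Double.NaN, evals, false); // interval does not bracket a root
        }
        while (hi - lo > absTol) {
            if (evals >= maxEvals) {
                return new Root(Double.NaN, evals, false); // too many evaluations
            }
            final double mid = 0.5 * (lo + hi);
            final double fMid = f.applyAsDouble(mid);
            evals++;
            if (fMid == 0) {
                return new Root(mid, evals, true);
            }
            if (Math.signum(fMid) == Math.signum(fLo)) { lo = mid; fLo = fMid; } else { hi = mid; }
        }
        return new Root(0.5 * (lo + hi), evals, true);
    }

    static Result solveAll(IntFunction<DoubleUnaryOperator> objectives, double[] previous,
                           double lo, double hi, double absTol, int maxEvals, int numThreads) {
        final ForkJoinPool pool = new ForkJoinPool(numThreads);
        final Root[] roots;
        try {
            roots = pool.submit(() -> IntStream.range(0, previous.length).parallel()
                    .mapToObj(i -> bisect(objectives.apply(i), lo, hi, absTol, maxEvals))
                    .toArray(Root[]::new)).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted during parallel update", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Failure in parallel update", e);
        } finally {
            pool.shutdown(); // release the worker threads
        }
        final double[] values = new double[previous.length];
        int maxEvaluations = 0;
        double errNormInfinity = 0;
        for (int i = 0; i < values.length; i++) {
            values[i] = roots[i].found ? roots[i].x : previous[i]; // fall back to the previous value
            maxEvaluations = Math.max(maxEvaluations, roots[i].evaluations);
            errNormInfinity = Math.max(errNormInfinity, Math.abs(values[i] - previous[i]));
        }
        return new Result(values, maxEvaluations, errNormInfinity);
    }
}
```

For example, `solveAll(i -> x -> x - (i + 1), new double[] {0.5, 0.5, 0.5}, 0.0, 2.5, 1e-10, 200, 2)` finds the roots 1 and 2 for the first two elements and keeps 0.5 for the third, whose root (3) lies outside [0, 2.5].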

Example 7 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk by broadinstitute.

In class CoverageModelEMWorkspace, method updateTargetUnexplainedVarianceIsotropic.

/**
 * M-step update of unexplained variance in the isotropic mode.
 *
 * @return a {@link SubroutineSignal} object containing "error_norm" and "iterations" fields
 */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");
    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);
    final RobustBrentSolver solver = new RobustBrentSolver(config.getTargetSpecificVarianceRelativeTolerance(), config.getTargetSpecificVarianceAbsoluteTolerance(), CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, config.getTargetSpecificVarianceSolverNumBisections(), config.getTargetSpecificVarianceSolverRefinementDepth());
    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(), objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }
    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int maxIterations = solver.getEvaluations();
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build();
}
Also used: ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException)
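In the isotropic mode a single unexplained variance is shared by all targets, so the objective is the sum of the per-worker objectives (the `(a, b) -> a + b` reduction above), the root is found once, and the result is written to every target. The following is a minimal sketch of that flow under stated assumptions: the names `IsotropicUpdate`, `summed`, `bisect` and `update` are hypothetical, in-memory `DoubleUnaryOperator` blocks replace the Spark workers, bisection replaces `RobustBrentSolver`, and the tolerance (1e-12) and evaluation budget (200) are illustrative values only.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch of an isotropic M-step: one scalar shared by every target. */
final class IsotropicUpdate {

    /** Objective of the whole problem: the sum of the per-block objectives. */
    static DoubleUnaryOperator summed(List<DoubleUnaryOperator> blocks) {
        return psi -> blocks.stream().mapToDouble(f -> f.applyAsDouble(psi)).reduce(0.0, (a, b) -> a + b);
    }

    /** Bisection stand-in for a Brent-type solver; returns NaN on no bracketing or too many evaluations. */
    static double bisect(DoubleUnaryOperator f, double lo, double hi, double absTol, int maxEvals) {
        double fLo = f.applyAsDouble(lo);
        final double fHi = f.applyAsDouble(hi);
        if (fLo == 0) return lo;
        if (fHi == 0) return hi;
        if (Math.signum(fLo) == Math.signum(fHi)) return Double.NaN;
        for (int evals = 2; hi - lo > absTol; evals++) {
            if (evals >= maxEvals) return Double.NaN;
            final double mid = 0.5 * (lo + hi);
            final double fMid = f.applyAsDouble(mid);
            if (fMid == 0) return mid;
            if (Math.signum(fMid) == Math.signum(fLo)) { lo = mid; fLo = fMid; } else { hi = mid; }
        }
        return 0.5 * (lo + hi);
    }

    /** Solves once for the shared value, keeps the old value on failure, and fills every target with it. */
    static double[] update(List<DoubleUnaryOperator> blocks, double oldValue, int numTargets, double upperLimit) {
        double newValue = bisect(summed(blocks), 0, upperLimit, 1e-12, 200);
        if (Double.isNaN(newValue)) {
            System.err.println("Warning: root search failed; keeping the previous value");
            newValue = oldValue;
        }
        final double[] psi = new double[numTargets];
        Arrays.fill(psi, newValue);
        return psi;
    }
}
```

With the two blocks `p -> p - 1` and `p -> p - 3`, the summed objective is `2p - 4`, so every entry of the returned array is (up to the tolerance) 2; a block such as `p -> p + 1` has no root in [0, upperLimit], so the old value is kept.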

Example 8 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk-protected by broadinstitute.

In class RobustBrentSolverUnitTest, method simpleTest.

/**
 * Test on a 4th degree polynomial with 4 real roots at x = 0, 1, 2, 3. This objective function is positive for
 * sufficiently large positive and negative arguments, and in particular at both ends of the search interval
 * [-1, 4]. Therefore, the simple Brent solver throws a {@link NoBracketingException} because the search interval
 * does not bracket a root. The robust Brent solver, however, subdivides the given search interval and finds a
 * bracketing sub-interval.
 *
 * The "best" root according to the given merit function (set to the antiderivative of the objective function)
 * is the one at x = 0. We require the robust solver to output x = 0, and the simple solver to fail.
 */
@Test
public void simpleTest() {
    final UnivariateFunction objFunc = x -> 30 * x * (x - 1) * (x - 2) * (x - 3);
    final UnivariateFunction meritFunc = x -> 6 * FastMath.pow(x, 5) - 45 * FastMath.pow(x, 4) + 110 * FastMath.pow(x, 3) - 90 * FastMath.pow(x, 2);
    final RobustBrentSolver solverRobust = new RobustBrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC, meritFunc, 4, 1);
    final BrentSolver solverSimple = new BrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC);
    final double xRobust = solverRobust.solve(100, objFunc, -1, 4);
    Assert.assertEquals(xRobust, 0, DEF_ABS_ACC);
    boolean simpleSolverFails = false;
    try {
        /* this will fail */
        solverSimple.solve(100, objFunc, -1, 4);
    } catch (final NoBracketingException ex) {
        simpleSolverFails = true;
    }
    Assert.assertTrue(simpleSolverFails);
}
Also used: List(java.util.List) Assert(org.testng.Assert) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) FastMath(org.apache.commons.math3.util.FastMath) Test(org.testng.annotations.Test) BrentSolver(org.apache.commons.math3.analysis.solvers.BrentSolver) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException)
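The objective in this test is the derivative of the merit function: expanding 30x(x - 1)(x - 2)(x - 3) gives 30x^4 - 180x^3 + 330x^2 - 180x, whose antiderivative with zero constant is 6x^5 - 45x^4 + 110x^3 - 90x^2. At the four roots the merit function takes the values M(0) = 0, M(1) = -19, M(2) = -8 and M(3) = -27, so x = 0 is the root with the largest merit. The sketch below illustrates the subdivide-then-select idea described in the test's Javadoc. It is a hypothetical, simplified reimplementation (`SubdividingRootSearch`), not GATK's `RobustBrentSolver`: it splits the interval into a fixed number of equal sub-intervals instead of using `numBisections` and `refinementDepth` (whose exact semantics are not shown here), solves every sign-changing sub-interval by bisection, and assumes that "best" means the largest merit value, which is consistent with the test expecting x = 0.

```java
import java.util.function.DoubleUnaryOperator;

/**
 * Hypothetical illustration of a subdividing root search; NOT GATK's RobustBrentSolver.
 * The interval is split into equal sub-intervals, each sign-changing sub-interval is solved
 * by bisection, and the root with the largest merit value is returned.
 */
final class SubdividingRootSearch {

    /** Bisection on [a, b], assuming f(a) and f(b) have opposite signs; capped at 200 halvings. */
    static double bisect(DoubleUnaryOperator f, double a, double b, double absTol) {
        double fa = f.applyAsDouble(a);
        for (int iter = 0; iter < 200 && b - a > absTol; iter++) {
            final double mid = 0.5 * (a + b);
            final double fMid = f.applyAsDouble(mid);
            if (fMid == 0) {
                return mid;
            }
            if (Math.signum(fMid) == Math.signum(fa)) {
                a = mid;
                fa = fMid;
            } else {
                b = mid;
            }
        }
        return 0.5 * (a + b);
    }

    /** Returns the bracketed root with the largest merit value, or NaN if no sub-interval brackets a root. */
    static double solve(DoubleUnaryOperator f, DoubleUnaryOperator merit,
                        double lo, double hi, int numSubintervals, double absTol) {
        double best = Double.NaN;
        double bestMerit = Double.NEGATIVE_INFINITY;
        final double width = (hi - lo) / numSubintervals;
        for (int k = 0; k < numSubintervals; k++) {
            final double a = lo + k * width;
            final double b = (k == numSubintervals - 1) ? hi : a + width;
            final double fa = f.applyAsDouble(a);
            final double fb = f.applyAsDouble(b);
            final double root;
            if (fa == 0) {
                root = a;
            } else if (fb == 0) {
                root = b;
            } else if (Math.signum(fa) != Math.signum(fb)) {
                root = bisect(f, a, b, absTol);
            } else {
                continue; // this sub-interval does not bracket a root
            }
            final double m = merit.applyAsDouble(root);
            if (m > bestMerit) {
                best = root;
                bestMerit = m;
            }
        }
        return best;
    }
}
```

With 16 sub-intervals on [-1, 4], each of the four roots falls strictly inside its own sub-interval, so all four are found and the merit comparison selects x = 0, even though the full interval [-1, 4] does not bracket a root.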

Example 9 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk by broadinstitute.

In class CoverageModelEMComputeBlock, method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.

The method body and its "Also used" list are identical to Example 6.

Example 10 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project GDSC-SMLM by aherbert.

In class PeakResult, method computeI1.

/**
 * Compute the function I1 using numerical integration. See Mortensen et al. (2010) Nature Methods 7, 377-383,
 * SI equation 43.
 *
 * <pre>
 * I1 = 1 + sum [ ln(t) / (1 + t/rho) ] dt
 *    = - sum [ t * ln(t) / (t + rho) ] dt
 * </pre>
 *
 * where sum denotes the integral from 0 to 1. For rho = 0 the function returns 1.
 *
 * @param rho
 *            the value of rho in the integrand above
 * @param integrationPoints
 *            the number of integration points for the IterativeLegendreGaussIntegrator
 * @return the I1 value
 */
private static double computeI1(final double rho, int integrationPoints) {
    if (rho == 0)
        return 1;
    final double relativeAccuracy = 1e-4;
    final double absoluteAccuracy = 1e-8;
    final int minimalIterationCount = 3;
    final int maximalIterationCount = 32;
    // Use an integrator that does not use the boundary since log(0) is undefined.
    UnivariateIntegrator i = new IterativeLegendreGaussIntegrator(integrationPoints, relativeAccuracy, absoluteAccuracy, minimalIterationCount, maximalIterationCount);
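    // Specify the function to integrate: x * ln(x) / (x + rho)
    UnivariateFunction f = new UnivariateFunction() {

        @Override
        public double value(double x) {
            return x * Math.log(x) / (x + rho);
        }
    };
    // Integrate over [0, 1] with at most 2000 function evaluations; I1 is minus the integral
    final double i1 = -i.integrate(2000, f, 0, 1);
    return i1;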
}
Also used : IterativeLegendreGaussIntegrator(org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator) UnivariateIntegrator(org.apache.commons.math3.analysis.integration.UnivariateIntegrator) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction)
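The two forms of I1 in the Javadoc agree because their integrands add up to ln(t): ln(t) / (1 + t/rho) + t ln(t) / (t + rho) = (rho + t) ln(t) / (t + rho) = ln(t), and the integral of ln(t) from 0 to 1 is -1. Setting rho = 0 in the second form gives -(-1) = 1, which is why computeI1 returns 1 in that case; and since 0 < t / (t + rho) < 1 on (0, 1] for rho > 0, the second form also gives 0 < I1 < 1. The hypothetical `I1Check` class below checks this numerically with a composite midpoint rule, a swapped-in technique rather than the `IterativeLegendreGaussIntegrator` that computeI1 uses; like Gauss-Legendre quadrature, it never evaluates the endpoint t = 0, where ln(t) diverges.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical cross-check of the two forms of I1 using a composite midpoint rule (pure JDK). */
final class I1Check {

    /** Composite midpoint rule on [a, b] with n panels; only interior points are evaluated. */
    static double midpoint(DoubleUnaryOperator f, double a, double b, int n) {
        final double h = (b - a) / n;
        double sum = 0;
        for (int k = 0; k < n; k++) {
            sum += f.applyAsDouble(a + (k + 0.5) * h);
        }
        return sum * h;
    }

    /** First form: I1 = 1 + integral_0^1 ln(t) / (1 + t / rho) dt (requires rho > 0). */
    static double firstForm(double rho, int n) {
        return 1 + midpoint(t -> Math.log(t) / (1 + t / rho), 0, 1, n);
    }

    /** Second form, as integrated by computeI1: I1 = -integral_0^1 t ln(t) / (t + rho) dt. */
    static double secondForm(double rho, int n) {
        return -midpoint(t -> t * Math.log(t) / (t + rho), 0, 1, n);
    }
}
```

The midpoint rule converges slowly near the logarithmic singularity of the first form (its error there is roughly proportional to the panel width), so a comparison such as `firstForm(0.5, 200000)` against `secondForm(0.5, 200000)` should use a loose tolerance.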

Aggregations

UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction): 17
TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException): 7
SimpsonIntegrator (org.apache.commons.math3.analysis.integration.SimpsonIntegrator): 6
NoBracketingException (org.apache.commons.math3.exception.NoBracketingException): 6
FastMath (org.apache.commons.math3.util.FastMath): 6
VisibleForTesting (com.google.common.annotations.VisibleForTesting): 4
List (java.util.List): 4
Collectors (java.util.stream.Collectors): 4
IntStream (java.util.stream.IntStream): 4
Nonnull (javax.annotation.Nonnull): 4
Nullable (javax.annotation.Nullable): 4
ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair): 4
ImmutableTriple (org.apache.commons.lang3.tuple.ImmutableTriple): 4
UnivariateIntegrator (org.apache.commons.math3.analysis.integration.UnivariateIntegrator): 4
RobustBrentSolver (org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver): 4
Utils (org.broadinstitute.hellbender.utils.Utils): 4
ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils): 4
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 4
Nd4j (org.nd4j.linalg.factory.Nd4j): 4
NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex): 4