Example 1 with SearchInterval

use of org.apache.commons.math3.optim.univariate.SearchInterval in project GDSC-SMLM by aherbert.

the class FIRE method findMin.

private UnivariatePointValuePair findMin(UnivariatePointValuePair current, UnivariateOptimizer o, UnivariateFunction f, double qValue, double factor) {
    try {
        // Bracket a minimum of f, starting the search from the points qValue * factor and qValue / factor
        BracketFinder bracket = new BracketFinder();
        bracket.search(f, GoalType.MINIMIZE, qValue * factor, qValue / factor);
        // Minimize within the bracket, using the bracket's middle point as the start value
        UnivariatePointValuePair next = o.optimize(GoalType.MINIMIZE, new MaxEval(3000), new SearchInterval(bracket.getLo(), bracket.getHi(), bracket.getMid()), new UnivariateObjectiveFunction(f));
        if (next == null)
            return current;
        //System.out.printf("LineMin [%.1f]  %f = %f\n", factor, next.getPoint(), next.getValue());
        // Keep whichever of the two results has the lower function value
        if (current != null)
            return (next.getValue() < current.getValue()) ? next : current;
        return next;
    } catch (Exception e) {
        // Bracketing or optimization failed (e.g. the evaluation limit was exceeded): keep the current result
        return current;
    }
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) BracketFinder(org.apache.commons.math3.optim.univariate.BracketFinder) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException)
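
The three values passed to SearchInterval in the example above are a lower bound, an upper bound and a start value. In Commons Math 3 the constructor rejects a lower bound that is not strictly below the upper bound, and a start value that lies outside the bounds. The plain-Java sketch below illustrates that contract without depending on the library. IntervalSketch, Interval and fromBracket are hypothetical names used only for illustration, and clamping the guess into the interval is an assumption about how one might sanitize an arbitrary start value, not something the example above does.

```java
final class IntervalSketch {

    /** Hypothetical stand-in for a (lo, hi, start) search interval; not the Commons Math class. */
    static final class Interval {
        final double lo;
        final double hi;
        final double start;

        Interval(double lo, double hi, double start) {
            // The lower bound must be strictly below the upper bound.
            if (!(lo < hi)) {
                throw new IllegalArgumentException("lo must be < hi: " + lo + ", " + hi);
            }
            // The start value must lie inside the closed interval [lo, hi].
            if (!(start >= lo && start <= hi)) {
                throw new IllegalArgumentException("start outside [lo, hi]: " + start);
            }
            this.lo = lo;
            this.hi = hi;
            this.start = start;
        }
    }

    /**
     * Builds an interval from two bracket ends given in either order,
     * clamping the guess into the interval (an illustrative choice).
     * Throws IllegalArgumentException if both ends are equal.
     */
    static Interval fromBracket(double a, double b, double guess) {
        double lo = Math.min(a, b);
        double hi = Math.max(a, b);
        double start = Math.min(hi, Math.max(lo, guess));
        return new Interval(lo, hi, start);
    }

    public static void main(String[] args) {
        // The ends are given in reverse order and the guess lies above the interval.
        Interval iv = fromBracket(4.0, 1.0, 9.0);
        System.out.println(iv.lo + " " + iv.hi + " " + iv.start);
    }
}
```

Ordering the ends matters for a call such as findMin, where qValue * factor may be larger or smaller than qValue / factor depending on whether factor exceeds 1.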

Example 2 with SearchInterval

use of org.apache.commons.math3.optim.univariate.SearchInterval in project gatk-protected by broadinstitute.

the class PosteriorSummaryUtils method calculatePosteriorMode.

/**
 * Given a list of posterior samples, returns an estimate of the posterior mode (using
 * mllib kernel density estimation in {@link KernelDensity} and {@link BrentOptimizer}).
 * Note that the estimate may be poor if the number of samples is small (resulting in poor kernel density
 * estimation), or if the posterior is not unimodal (or is sufficiently pathological otherwise). If the samples
 * contain {@link Double#NaN}, {@link Double#NaN} will be returned.
 * @param samples   posterior samples, cannot be {@code null} and number of samples must be greater than 0
 * @param ctx       {@link JavaSparkContext} used by {@link KernelDensity} for mllib kernel density estimation
 * @return          estimate of the posterior mode
 */
public static double calculatePosteriorMode(final List<Double> samples, final JavaSparkContext ctx) {
    Utils.nonNull(samples);
    Utils.validateArg(samples.size() > 0, "Number of samples must be greater than zero.");
    //calculate sample min, max, mean, and standard deviation
    final double sampleMin = Collections.min(samples);
    final double sampleMax = Collections.max(samples);
    final double sampleMean = new Mean().evaluate(Doubles.toArray(samples));
    final double sampleStandardDeviation = new StandardDeviation().evaluate(Doubles.toArray(samples));
    //if samples are all the same or contain NaN, can simply return mean
    if (sampleStandardDeviation == 0. || Double.isNaN(sampleMean)) {
        return sampleMean;
    }
    //use Silverman's rule to set bandwidth for kernel density estimation from sample standard deviation
    //see https://en.wikipedia.org/wiki/Kernel_density_estimation#Practical_estimation_of_the_bandwidth
    final double bandwidth = SILVERMANS_RULE_CONSTANT * sampleStandardDeviation * Math.pow(samples.size(), SILVERMANS_RULE_EXPONENT);
    //use kernel density estimation to approximate posterior from samples
    final KernelDensity pdf = new KernelDensity().setSample(ctx.parallelize(samples, 1)).setBandwidth(bandwidth);
    //use Brent optimization to find mode (i.e., maximum) of kernel-density-estimated posterior
    final BrentOptimizer optimizer = new BrentOptimizer(RELATIVE_TOLERANCE, RELATIVE_TOLERANCE * (sampleMax - sampleMin));
    final UnivariateObjectiveFunction objective = new UnivariateObjectiveFunction(f -> pdf.estimate(new double[] { f })[0]);
    //search for mode within sample range, start near sample mean
    final SearchInterval searchInterval = new SearchInterval(sampleMin, sampleMax, sampleMean);
    return optimizer.optimize(objective, GoalType.MAXIMIZE, searchInterval, BRENT_MAX_EVAL).getPoint();
}
Also used : Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) KernelDensity(org.apache.spark.mllib.stat.KernelDensity) StandardDeviation(org.apache.commons.math3.stat.descriptive.moment.StandardDeviation)
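
The steps in calculatePosteriorMode (bandwidth from Silverman's rule, kernel density estimate, then a bounded one-dimensional maximization over the sample range) can be sketched in plain Java without Spark or Commons Math. This is a sketch under stated assumptions, not the project's implementation: the values 1.06 and -0.2 are the commonly quoted rule-of-thumb constants and are assumed here because the snippet does not show SILVERMANS_RULE_CONSTANT or SILVERMANS_RULE_EXPONENT; a Gaussian kernel is assumed; and golden-section search is swapped in for BrentOptimizer, so unlike SearchInterval it takes no start value. All class and method names are hypothetical.

```java
import java.util.function.DoubleUnaryOperator;

final class PosteriorModeSketch {
    // Assumed rule-of-thumb values; the original project's constants are not shown in the snippet.
    static final double SILVERMAN_CONSTANT = 1.06;
    static final double SILVERMAN_EXPONENT = -0.2;

    static double mean(double[] xs) {
        double sum = 0.0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    /** Sample standard deviation (n - 1 denominator); 0 for fewer than two samples. */
    static double sampleStd(double[] xs) {
        if (xs.length < 2) return 0.0;
        double m = mean(xs);
        double ss = 0.0;
        for (double x : xs) ss += (x - m) * (x - m);
        return Math.sqrt(ss / (xs.length - 1));
    }

    /** Gaussian kernel density estimate at x with bandwidth h. */
    static double gaussianKde(double[] samples, double h, double x) {
        double sum = 0.0;
        for (double s : samples) {
            double z = (x - s) / h;
            sum += Math.exp(-0.5 * z * z);
        }
        return sum / (samples.length * h * Math.sqrt(2.0 * Math.PI));
    }

    /** Golden-section search for the maximizer of a unimodal f on [lo, hi]. */
    static double goldenSectionMax(DoubleUnaryOperator f, double lo, double hi, double tol) {
        final double invPhi = (Math.sqrt(5.0) - 1.0) / 2.0;
        double a = lo, b = hi;
        double c = b - invPhi * (b - a);
        double d = a + invPhi * (b - a);
        double fc = f.applyAsDouble(c), fd = f.applyAsDouble(d);
        // The iteration cap guards against a tolerance below floating-point resolution.
        for (int iter = 0; iter < 200 && b - a > tol; iter++) {
            if (fc > fd) {          // maximum lies in [a, d]
                b = d; d = c; fd = fc;
                c = b - invPhi * (b - a);
                fc = f.applyAsDouble(c);
            } else {                // maximum lies in [c, b]
                a = c; c = d; fc = fd;
                d = a + invPhi * (b - a);
                fd = f.applyAsDouble(d);
            }
        }
        return 0.5 * (a + b);
    }

    static double posteriorMode(double[] samples) {
        if (samples.length == 0) throw new IllegalArgumentException("need at least one sample");
        double m = mean(samples);
        double sd = sampleStd(samples);
        // Identical samples or NaN input: return the mean, as in the example above.
        if (sd == 0.0 || Double.isNaN(m)) return m;
        double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
        for (double s : samples) { min = Math.min(min, s); max = Math.max(max, s); }
        double h = SILVERMAN_CONSTANT * sd * Math.pow(samples.length, SILVERMAN_EXPONENT);
        return goldenSectionMax(x -> gaussianKde(samples, h, x), min, max, 1e-6 * (max - min));
    }

    public static void main(String[] args) {
        double[] samples = {1.0, 1.5, 2.0, 2.0, 2.0, 2.5, 3.0};
        System.out.println(posteriorMode(samples));
    }
}
```

Like the original, the sketch is only reliable when the estimated density is unimodal on the sample range; golden-section search can settle on a local maximum otherwise.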

Example 3 with SearchInterval

use of org.apache.commons.math3.optim.univariate.SearchInterval in project gatk-protected by broadinstitute.

the class CNLOHCaller method optimizeIt.

private double optimizeIt(final Function<Double, Double> objectiveFxn, final SearchInterval searchInterval) {
    final MaxEval BRENT_MAX_EVAL = new MaxEval(1000);
    final double RELATIVE_TOLERANCE = 0.001;
    final double ABSOLUTE_TOLERANCE = 0.001;
    final BrentOptimizer OPTIMIZER = new BrentOptimizer(RELATIVE_TOLERANCE, ABSOLUTE_TOLERANCE);
    final UnivariateObjectiveFunction objective = new UnivariateObjectiveFunction(x -> objectiveFxn.apply(x));
    return OPTIMIZER.optimize(objective, GoalType.MAXIMIZE, searchInterval, BRENT_MAX_EVAL).getPoint();
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer)

Example 4 with SearchInterval

use of org.apache.commons.math3.optim.univariate.SearchInterval in project gatk-protected by broadinstitute.

the class CNLOHCaller method calcNewRhos.

private double[] calcNewRhos(final List<ACNVModeledSegment> segments, final List<double[][][]> responsibilitiesBySeg, final double lambda, final double[] rhos, final int[] mVals, final int[] nVals, final JavaSparkContext ctx) {
    // Since, we pass in the entire responsibilities matrix, we need the correct index for each rho.  That, and the
    //  fact that this is a univariate objective function, means we need to create an instance for each rho.  And
    //  then we blast across Spark.
    final List<Pair<? extends Function<Double, Double>, SearchInterval>> objectives = IntStream.range(0, rhos.length).mapToObj(i -> new Pair<>(new Function<Double, Double>() {

        @Override
        public Double apply(Double rho) {
            return calculateESmnObjective(rho, segments, responsibilitiesBySeg, mVals, nVals, lambda, i);
        }
    }, new SearchInterval(0.0, 1.0, rhos[i]))).collect(Collectors.toList());
    final JavaRDD<Pair<? extends Function<Double, Double>, SearchInterval>> objectivesRDD = ctx.parallelize(objectives);
    final List<Double> resultsAsDouble = objectivesRDD.map(objective -> optimizeIt(objective.getFirst(), objective.getSecond())).collect();
    return resultsAsDouble.stream().mapToDouble(Double::doubleValue).toArray();
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) RealVector(org.apache.commons.math3.linear.RealVector) Function(java.util.function.Function) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) GammaDistribution(org.apache.commons.math3.distribution.GammaDistribution) Gamma(org.apache.commons.math3.special.Gamma) ACNVModeledSegment(org.broadinstitute.hellbender.tools.exome.ACNVModeledSegment) BaseAbstractUnivariateIntegrator(org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) MatrixUtils(org.apache.commons.math3.linear.MatrixUtils) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) SimpsonIntegrator(org.apache.commons.math3.analysis.integration.SimpsonIntegrator) JavaRDD(org.apache.spark.api.java.JavaRDD) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Pair(org.apache.commons.math3.util.Pair) Collectors(java.util.stream.Collectors) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) Serializable(java.io.Serializable) List(java.util.List) Percentile(org.apache.commons.math3.stat.descriptive.rank.Percentile) Logger(org.apache.logging.log4j.Logger) MathUtils(org.broadinstitute.hellbender.utils.MathUtils) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) Utils(org.broadinstitute.hellbender.utils.Utils) RealMatrix(org.apache.commons.math3.linear.RealMatrix) VisibleForTesting(com.google.common.annotations.VisibleForTesting) 
HomoSapiensConstants(org.broadinstitute.hellbender.utils.variant.HomoSapiensConstants) MaxEval(org.apache.commons.math3.optim.MaxEval) LogManager(org.apache.logging.log4j.LogManager)
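
The comment in calcNewRhos describes the pattern: the objective depends on many parameters, so one univariate function is created per index, each capturing its own index, and the independent optimizations are then run in parallel. A plain-Java sketch of that pattern follows. It is illustrative only: a parallel stream stands in for Spark, a coarse grid search stands in for BrentOptimizer, and the objective (a negated squared distance to a per-index target) is a hypothetical placeholder for calculateESmnObjective. All names are assumptions.

```java
import java.util.function.DoubleUnaryOperator;
import java.util.stream.IntStream;

final class PerIndexObjectivesSketch {

    /** Coarse grid search for the argmax of f on [lo, hi]; a stand-in for a real univariate optimizer. */
    static double gridArgMax(DoubleUnaryOperator f, double lo, double hi, int steps) {
        double bestX = lo;
        double bestF = f.applyAsDouble(lo);
        for (int k = 1; k <= steps; k++) {
            double x = lo + (hi - lo) * k / steps;
            double fx = f.applyAsDouble(x);
            if (fx > bestF) {
                bestF = fx;
                bestX = x;
            }
        }
        return bestX;
    }

    /** Hypothetical objective over all parameters in which only the one at 'index' is varied. */
    static double objective(double candidate, double[] targets, int index) {
        double d = candidate - targets[index];
        return -d * d;
    }

    /** Creates one univariate objective per index and maximizes each on [0, 1] in parallel. */
    static double[] optimizeEach(double[] targets) {
        return IntStream.range(0, targets.length)
                .parallel()
                // Each lambda captures its own index i, so every objective is univariate.
                .mapToDouble(i -> gridArgMax(x -> objective(x, targets, i), 0.0, 1.0, 1000))
                .toArray();
    }

    public static void main(String[] args) {
        double[] result = optimizeEach(new double[] {0.25, 0.5, 0.75});
        for (double r : result) {
            System.out.println(r);
        }
    }
}
```

The results come back in index order because toArray on an ordered stream preserves encounter order even when the stream is parallel.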

Example 5 with SearchInterval

use of org.apache.commons.math3.optim.univariate.SearchInterval in project gatk by broadinstitute.

the class CNLOHCaller method calcNewRhos.

private double[] calcNewRhos(final List<ACNVModeledSegment> segments, final List<double[][][]> responsibilitiesBySeg, final double lambda, final double[] rhos, final int[] mVals, final int[] nVals, final JavaSparkContext ctx) {
    // Since, we pass in the entire responsibilities matrix, we need the correct index for each rho.  That, and the
    //  fact that this is a univariate objective function, means we need to create an instance for each rho.  And
    //  then we blast across Spark.
    final List<Pair<? extends Function<Double, Double>, SearchInterval>> objectives = IntStream.range(0, rhos.length).mapToObj(i -> new Pair<>(new Function<Double, Double>() {

        @Override
        public Double apply(Double rho) {
            return calculateESmnObjective(rho, segments, responsibilitiesBySeg, mVals, nVals, lambda, i);
        }
    }, new SearchInterval(0.0, 1.0, rhos[i]))).collect(Collectors.toList());
    final JavaRDD<Pair<? extends Function<Double, Double>, SearchInterval>> objectivesRDD = ctx.parallelize(objectives);
    final List<Double> resultsAsDouble = objectivesRDD.map(objective -> optimizeIt(objective.getFirst(), objective.getSecond())).collect();
    return resultsAsDouble.stream().mapToDouble(Double::doubleValue).toArray();
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) RealVector(org.apache.commons.math3.linear.RealVector) Function(java.util.function.Function) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) GammaDistribution(org.apache.commons.math3.distribution.GammaDistribution) Gamma(org.apache.commons.math3.special.Gamma) ACNVModeledSegment(org.broadinstitute.hellbender.tools.exome.ACNVModeledSegment) BaseAbstractUnivariateIntegrator(org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) MatrixUtils(org.apache.commons.math3.linear.MatrixUtils) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) SimpsonIntegrator(org.apache.commons.math3.analysis.integration.SimpsonIntegrator) JavaRDD(org.apache.spark.api.java.JavaRDD) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Pair(org.apache.commons.math3.util.Pair) Collectors(java.util.stream.Collectors) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) Serializable(java.io.Serializable) List(java.util.List) Percentile(org.apache.commons.math3.stat.descriptive.rank.Percentile) Logger(org.apache.logging.log4j.Logger) MathUtils(org.broadinstitute.hellbender.utils.MathUtils) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) Utils(org.broadinstitute.hellbender.utils.Utils) RealMatrix(org.apache.commons.math3.linear.RealMatrix) VisibleForTesting(com.google.common.annotations.VisibleForTesting) 
HomoSapiensConstants(org.broadinstitute.hellbender.utils.variant.HomoSapiensConstants) MaxEval(org.apache.commons.math3.optim.MaxEval) LogManager(org.apache.logging.log4j.LogManager)

Aggregations

UnivariateObjectiveFunction (org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) 12
MaxEval (org.apache.commons.math3.optim.MaxEval) 10
BrentOptimizer (org.apache.commons.math3.optim.univariate.BrentOptimizer) 10
SearchInterval (org.apache.commons.math3.optim.univariate.SearchInterval) 10
UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction) 4
UnivariatePointValuePair (org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) 4
VisibleForTesting (com.google.common.annotations.VisibleForTesting) 2
Serializable (java.io.Serializable) 2
Arrays (java.util.Arrays) 2
List (java.util.List) 2
Function (java.util.function.Function) 2
Collectors (java.util.stream.Collectors) 2
IntStream (java.util.stream.IntStream) 2
BaseAbstractUnivariateIntegrator (org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator) 2
SimpsonIntegrator (org.apache.commons.math3.analysis.integration.SimpsonIntegrator) 2
GammaDistribution (org.apache.commons.math3.distribution.GammaDistribution) 2
NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution) 2
TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException) 2
ArrayRealVector (org.apache.commons.math3.linear.ArrayRealVector) 2
MatrixUtils (org.apache.commons.math3.linear.MatrixUtils) 2