Search in sources:

Example 1 with Gamma

Use of org.apache.commons.math3.special.Gamma in project GDSC-SMLM by aherbert.

The class EMGainAnalysis, method fit.

/**
 * Fit the EM-gain distribution (Gaussian convolved with Gamma) to a histogram.
 * <p>
 * The histogram is first plotted together with the distribution predicted from the initial
 * estimates. The photon count is initially estimated as {@code (mean - bias) / gain}, using the
 * histogram mean; if that mean is below the bias the bias is lowered so the estimate is not negative.
 * The parameters (photons, gain, noise, bias) are then refined by running an unbounded and a bounded
 * Powell optimiser in each of three rounds, keeping the best solution found. On success the
 * {@code gain}, {@code noise} and {@code bias} fields are replaced by the fitted values (the bias is
 * rounded to a whole number) and the fitted distribution is re-plotted, optionally with the
 * Poisson-Gamma-Gaussian approximation.
 *
 * @param h
 *            The distribution (histogram)
 */
private void fit(int[] h) {
    final int[] limits = limits(h);
    final double[] x = getX(limits);
    final double[] y = getY(h, limits);
    Plot2 plot = new Plot2(TITLE, "ADU", "Frequency");
    double yMax = Maths.max(y);
    plot.setLimits(limits[0], limits[1], 0, yMax);
    plot.setColor(Color.black);
    plot.addPoints(x, y, Plot2.DOT);
    Utils.display(TITLE, plot);
    // Estimate remaining parameters.
    // Assuming a gamma_distribution(shape,scale) then mean = shape * scale
    // scale = gain
    // shape = Photons = mean / gain
    double mean = getMean(h) - bias;
    // Note: if the bias is too high then the mean will be negative. Just move the bias.
    // Shift by the smallest whole number that makes the mean non-negative. This gives the same
    // result as repeatedly decrementing the bias by 1, but in constant time (and it cannot spin
    // forever on a non-finite mean).
    if (mean < 0) {
        final double shift = Math.ceil(-mean);
        bias -= shift;
        mean += shift;
    }
    double photons = mean / gain;
    if (simulate)
        Utils.log("Simulated bias=%d, gain=%s, noise=%s, photons=%s", (int) _bias, Utils.rounded(_gain), Utils.rounded(_noise), Utils.rounded(_photons));
    Utils.log("Estimate bias=%d, gain=%s, noise=%s, photons=%s", (int) bias, Utils.rounded(gain), Utils.rounded(noise), Utils.rounded(photons));
    final int max = (int) x[x.length - 1];
    double[] g = pdf(max, photons, gain, noise, (int) bias);
    plot.setColor(Color.blue);
    plot.addPoints(x, g, Plot2.LINE);
    Utils.display(TITLE, plot);
    // Perform a fit
    CustomPowellOptimizer o = new CustomPowellOptimizer(1e-6, 1e-16, 1e-6, 1e-16);
    double[] startPoint = new double[] { photons, gain, noise, bias };
    int maxEval = 3000;
    String[] paramNames = { "Photons", "Gain", "Noise", "Bias" };
    // Set bounds
    double[] lower = new double[] { 0, 0.5 * gain, 0, bias - noise };
    double[] upper = new double[] { 2 * photons, 2 * gain, gain, bias + noise };
    // Restart a fixed number of times, each round starting from the best solution so far.
    // TODO - Maybe fix this with a better optimiser. This needs to be tested on real data.
    PointValuePair solution = null;
    for (int iter = 0; iter < 3; iter++) {
        IJ.showStatus("Fitting histogram ... Iteration " + iter);
        try {
            // Basic Powell optimiser
            MultivariateFunction fun = getFunction(limits, y, max, maxEval);
            PointValuePair optimum = o.optimize(new MaxEval(maxEval), new ObjectiveFunction(fun), GoalType.MINIMIZE, new InitialGuess((solution == null) ? startPoint : solution.getPointRef()));
            if (solution == null || optimum.getValue() < solution.getValue()) {
                double[] point = optimum.getPointRef();
                // Check the bounds: an out-of-range fit is rejected by throwing, which is
                // caught and logged below so the bounded optimiser still gets a chance.
                for (int i = 0; i < point.length; i++) {
                    if (point[i] < lower[i] || point[i] > upper[i]) {
                        throw new RuntimeException(String.format("Fit out of estimated range: %s %f", paramNames[i], point[i]));
                    }
                }
                solution = optimum;
            }
        } catch (Exception e) {
            IJ.log("Powell error: " + e.getMessage());
            if (e instanceof TooManyEvaluationsException) {
                // Allow more evaluations on the next attempt
                maxEval = (int) (maxEval * 1.5);
            }
        }
        try {
            // Bounded Powell optimiser: optimise in an unbounded space mapped onto [lower, upper]
            MultivariateFunction fun = getFunction(limits, y, max, maxEval);
            MultivariateFunctionMappingAdapter adapter = new MultivariateFunctionMappingAdapter(fun, lower, upper);
            PointValuePair optimum = o.optimize(new MaxEval(maxEval), new ObjectiveFunction(adapter), GoalType.MINIMIZE, new InitialGuess(adapter.boundedToUnbounded((solution == null) ? startPoint : solution.getPointRef())));
            // Map the optimum back into the bounded parameter space
            double[] point = adapter.unboundedToBounded(optimum.getPointRef());
            optimum = new PointValuePair(point, optimum.getValue());
            if (solution == null || optimum.getValue() < solution.getValue()) {
                solution = optimum;
            }
        } catch (Exception e) {
            IJ.log("Bounded Powell error: " + e.getMessage());
            if (e instanceof TooManyEvaluationsException) {
                // Allow more evaluations on the next attempt
                maxEval = (int) (maxEval * 1.5);
            }
        }
    }
    IJ.showStatus("");
    IJ.showProgress(1);
    if (solution == null) {
        Utils.log("Failed to fit the distribution");
        return;
    }
    double[] point = solution.getPointRef();
    photons = point[0];
    gain = point[1];
    noise = point[2];
    bias = (int) Math.round(point[3]);
    String label = String.format("Fitted bias=%d, gain=%s, noise=%s, photons=%s", (int) bias, Utils.rounded(gain), Utils.rounded(noise), Utils.rounded(photons));
    Utils.log(label);
    if (simulate) {
        Utils.log("Relative Error bias=%s, gain=%s, noise=%s, photons=%s", Utils.rounded(relativeError(bias, _bias)), Utils.rounded(relativeError(gain, _gain)), Utils.rounded(relativeError(noise, _noise)), Utils.rounded(relativeError(photons, _photons)));
    }
    // Show the PoissonGammaGaussian approximation
    double[] f = null;
    if (showApproximation) {
        f = new double[x.length];
        PoissonGammaGaussianFunction fun = new PoissonGammaGaussianFunction(1.0 / gain, noise);
        final double expected = photons * gain;
        for (int i = 0; i < f.length; i++) {
            // The approximation is evaluated on bias-subtracted values
            f[i] = fun.likelihood(x[i] - bias, expected);
        }
        yMax = Maths.maxDefault(yMax, f);
    }
    // Replot
    g = pdf(max, photons, gain, noise, (int) bias);
    plot = new Plot2(TITLE, "ADU", "Frequency");
    plot.setLimits(limits[0], limits[1], 0, yMax * 1.05);
    plot.setColor(Color.black);
    plot.addPoints(x, y, Plot2.DOT);
    plot.setColor(Color.red);
    plot.addPoints(x, g, Plot2.LINE);
    plot.addLabel(0, 0, label);
    if (showApproximation) {
        plot.setColor(Color.blue);
        plot.addPoints(x, f, Plot2.LINE);
    }
    Utils.display(TITLE, plot);
}
Also used: MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) Plot2(ij.gui.Plot2) Point(java.awt.Point) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) PointValuePair(org.apache.commons.math3.optim.PointValuePair) MultivariateFunction(org.apache.commons.math3.analysis.MultivariateFunction) MultivariateFunctionMappingAdapter(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter) CustomPowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer) PoissonGammaGaussianFunction(gdsc.smlm.function.PoissonGammaGaussianFunction)

Example 2 with Gamma

Use of org.apache.commons.math3.special.Gamma in project GDSC-SMLM by aherbert.

The class PoissonGammaGaussianFunction, method likelihood.

/**
 * Compute the likelihood of an observed count given an expected count.
 * <p>
 * Three evaluation paths are used:
 * <ul>
 * <li>{@code sigma == 0}: a Poisson-Gamma distribution with no Gaussian convolution. For large
 * Bessel-function arguments the expression is evaluated in log space to avoid overflow.</li>
 * <li>{@code useApproximation}: delegates to {@code mortensenApproximation}.</li>
 * <li>otherwise: numerical evaluation of the Poisson-Gamma-Gaussian convolution (equation 7 of the
 * supplementary information of Chao, et al (2013) Nature Methods 10, 335-338). If the numerical
 * integrator exceeds its evaluation budget, {@code mortensenApproximation} is used instead.</li>
 * </ul>
 * The {@code sigma == 0} and full-evaluation paths never return less than
 * {@code minimumProbability}.
 * 
 * @param o
 *            The observed count
 * @param e
 *            The expected count
 * @return The likelihood
 */
public double likelihood(final double o, final double e) {
    // Use the same variables as the Python code
    final double cij = o;
    // convert to photons
    final double eta = alpha * e;
    if (sigma == 0) {
        // No convolution with a Gaussian. Simply evaluate for a Poisson-Gamma distribution.
        final double p;
        // Any observed count above zero
        if (cij > 0.0) {
            // The observed count converted to photons
            final double nij = alpha * cij;
            // exp(x) overflows a double for x above ~709. The Bessel argument is
            // x = 2 * sqrt(eta * nij), so direct evaluation overflows once
            // eta * nij > (709/2)^2 = 125670.25. Switch to log space well before that.
            if (eta * nij > 10000) {
                // Approximate Bessel function i1(x) when using large x:
                // i1(x) ~ exp(x)/sqrt(2*pi*x)
                // However the entire equation is logged (creating transform),
                // evaluated then raised to e to prevent overflow error on 
                // large exp(x)
                final double transform = 0.5 * Math.log(alpha * eta / cij) - nij - eta + 2 * Math.sqrt(eta * nij) - Math.log(twoSqrtPi * Math.pow(eta * nij, 0.25));
                p = FastMath.exp(transform);
            } else {
                // Second part of equation 135
                p = Math.sqrt(alpha * eta / cij) * FastMath.exp(-nij - eta) * Bessel.I1(2 * Math.sqrt(eta * nij));
            }
        } else if (cij == 0.0) {
            p = FastMath.exp(-eta);
        } else {
            p = 0;
        }
        return (p > minimumProbability) ? p : minimumProbability;
    } else if (useApproximation) {
        return mortensenApproximation(cij, eta);
    } else {
        // This code is the full evaluation of equation 7 from the supplementary information  
        // of the paper Chao, et al (2013) Nature Methods 10, 335-338.
        // It is the full evaluation of a Poisson-Gamma-Gaussian convolution PMF. 
        // Read noise
        final double sk = sigma;
        // Gain
        final double g = 1.0 / alpha;
        // Observed pixel value
        final double z = o;
        // Expected number of photons
        final double vk = eta;
        // Compute the integral to infinity of:
        // exp( -((z-u)/(sqrt(2)*s)).^2 - u/g ) * sqrt(vk*u/g) .* besseli(1, 2 * sqrt(vk*u/g)) ./ u;
        // vk / g
        final double vk_g = vk * alpha;
        final double sqrt2sigma = Math.sqrt(2) * sk;
        // Specify the function to integrate
        UnivariateFunction f = new UnivariateFunction() {

            @Override
            public double value(double u) {
                return eval(sqrt2sigma, z, vk_g, g, u);
            }
        };
        // Integrate to infinity is not necessary. The convolution of the function with the 
        // Gaussian should be adequately sampled using a nxSD around the maximum.
        // Find a bracket containing the maximum
        double lower, upper;
        double maxU = Math.max(1, o);
        double rLower = maxU;
        double rUpper = maxU + 1;
        double f1 = f.value(rLower);
        double f2 = f.value(rUpper);
        // Calculate the simple integral and the range
        double sum = f1 + f2;
        boolean searchUp = f2 > f1;
        if (searchUp) {
            // Walk upwards while the function is still increasing
            while (f2 > f1) {
                f1 = f2;
                rUpper += 1;
                f2 = f.value(rUpper);
                sum += f2;
            }
            maxU = rUpper - 1;
        } else {
            // Walk downwards while the function is still increasing in that direction.
            // Ensure that u stays above zero
            while (f1 > f2 && rLower > 1) {
                f2 = f1;
                rLower -= 1;
                f1 = f.value(rLower);
                sum += f1;
            }
            // The walk stops either because f(rLower) <= f(rLower + 1), so the maximum is at
            // rLower + 1, or because the lower limit was reached while the function was still
            // increasing towards it, so the maximum is at rLower.
            maxU = (f1 > f2) ? rLower : rLower + 1;
        }
        lower = Math.max(0, maxU - 5 * sk);
        upper = maxU + 5 * sk;
        if (useSimpleIntegration && lower > 0) {
            // Unit-step sum over the remaining points in the range
            for (double u = rLower - 1; u >= lower; u -= 1) {
                sum += f.value(u);
            }
            for (double u = rUpper + 1; u <= upper; u += 1) {
                sum += f.value(u);
            }
        } else {
            // Use Legendre-Gauss integrator
            try {
                final double relativeAccuracy = 1e-4;
                final double absoluteAccuracy = 1e-8;
                final int minimalIterationCount = 3;
                final int maximalIterationCount = 32;
                final int integrationPoints = 16;
                // Use an integrator that does not use the boundary since u=0 is undefined (divide by zero)
                UnivariateIntegrator i = new IterativeLegendreGaussIntegrator(integrationPoints, relativeAccuracy, absoluteAccuracy, minimalIterationCount, maximalIterationCount);
                sum = i.integrate(2000, f, lower, upper);
            } catch (TooManyEvaluationsException ex) {
                return mortensenApproximation(cij, eta);
            }
        }
        // Compute the final probability. The exp(-(z / (sqrt(2) * sigma))^2) term is a Gaussian
        // centred on zero (presumably the zero-photon case, outside the u > 0 integral).
        final double zScaled = z / sqrt2sigma;
        final double p = (FastMath.exp(-vk) / (sqrt2pi * sk)) * (FastMath.exp(-(zScaled * zScaled)) + sum);
        return (p > minimumProbability) ? p : minimumProbability;
    }
}
Also used : IterativeLegendreGaussIntegrator(org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) UnivariateIntegrator(org.apache.commons.math3.analysis.integration.UnivariateIntegrator)

Example 3 with Gamma

Use of org.apache.commons.math3.special.Gamma in project gatk-protected by broadinstitute.

The class CoverageModelEMWorkspace, method updateSampleUnexplainedVariance.

/**
 * E-step update of the sample-specific unexplained variance.
 * <p>
 * For each sample, a root of the sample-specific variance objective function is searched for in
 * {@code [0, config.getSampleSpecificVarianceUpperLimit()]}, starting from the midpoint of that
 * interval, using a {@link SynchronizedUnivariateSolver} backed by {@link RobustBrentSolver}
 * instances. Samples whose solver does not report {@code SUCCESS} get a value of 0 before admixing.
 * The new values are admixed with the current ones using {@code config.getMeanFieldAdmixingRatio()},
 * copied into {@code sampleUnexplainedVariance} and pushed to the workers.
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 * number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if solving is interrupted; the thread's interrupt status is restored first
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");
    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
        final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };
    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    IntStream.range(0, numSamples).forEach(si -> {
        final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
        syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations());
    });
    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
        newSampleSpecificVarianceMap.entrySet().forEach(entry -> {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            double val = 0;
            switch(summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of " + "function evaluations");
                    break;
                default:
                    // Any other outcome also leaves the value at 0; report it rather than hide it
                    logger.warn("Could not locate the root of gamma for sample " + sampleIndex + " (solver status: " + summary.status + ")");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        });
    } catch (final InterruptedException ex) {
        // Restore the interrupt status so code further up the stack can still observe it
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }
    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
    /* push to workers */
    // NOTE(review): the lambda ignores its `arr` argument and captures the local array instead;
    // confirm whether using the pushed value (`arr`) was intended.
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(Integer::doubleValue).sum() / numSamples);
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications)

Example 4 with Gamma

Use of org.apache.commons.math3.special.Gamma in project gatk by broadinstitute.

The class CoverageModelEMWorkspace, method updateSampleUnexplainedVariance.

/**
 * E-step update of the sample-specific unexplained variance.
 * <p>
 * For each sample, a root of the sample-specific variance objective function is searched for in
 * {@code [0, config.getSampleSpecificVarianceUpperLimit()]}, starting from the midpoint of that
 * interval, using a {@link SynchronizedUnivariateSolver} backed by {@link RobustBrentSolver}
 * instances. Samples whose solver does not report {@code SUCCESS} get a value of 0 before admixing.
 * The new values are admixed with the current ones using {@code config.getMeanFieldAdmixingRatio()},
 * copied into {@code sampleUnexplainedVariance} and pushed to the workers.
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 * number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if solving is interrupted; the thread's interrupt status is restored first
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");
    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
        final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };
    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    IntStream.range(0, numSamples).forEach(si -> {
        final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
        syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations());
    });
    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
        newSampleSpecificVarianceMap.entrySet().forEach(entry -> {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            double val = 0;
            switch(summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of " + "function evaluations");
                    break;
                default:
                    // Any other outcome also leaves the value at 0; report it rather than hide it
                    logger.warn("Could not locate the root of gamma for sample " + sampleIndex + " (solver status: " + summary.status + ")");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        });
    } catch (final InterruptedException ex) {
        // Restore the interrupt status so code further up the stack can still observe it
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }
    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
    /* push to workers */
    // NOTE(review): the lambda ignores its `arr` argument and captures the local array instead;
    // confirm whether using the pushed value (`arr`) was intended.
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(Integer::doubleValue).sum() / numSamples);
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications)

Example 5 with Gamma

Use of org.apache.commons.math3.special.Gamma in project gatk-protected by broadinstitute.

The class AlleleFractionSegmenterUnitTest, method generateCounts.

// visible for testing joint segmentation
/**
 * Builds a collection of simulated allelic counts, one per entry of the minor-allele-fraction
 * sequence.
 * <p>
 * Entry {@code k} of {@code minorAlleleFractionSequence} is paired with entry {@code k} of
 * {@code positions}, so {@code positions} must have at least as many elements. Every count shares
 * the same gamma-distributed bias generator (built from {@code trueParams} and {@code rng}) and the
 * outlier probability taken from {@code trueParams}.
 *
 * @param minorAlleleFractionSequence minor allele fraction for each simulated count
 * @param positions                   position of each simulated count
 * @param rng                         random source shared by all draws
 * @param trueParams                  global parameters used for the simulation
 * @return the simulated counts, in sequence order
 */
protected static AllelicCountCollection generateCounts(final List<Double> minorAlleleFractionSequence, final List<SimpleInterval> positions, final RandomGenerator rng, final AlleleFractionGlobalParameters trueParams) {
    // translate to Apache Commons' parametrization of the gamma distribution
    final GammaDistribution gammaBias = getGammaDistribution(trueParams, rng);
    final double pOutlier = trueParams.getOutlierProbability();
    final AllelicCountCollection simulated = new AllelicCountCollection();
    int idx = 0;
    while (idx < minorAlleleFractionSequence.size()) {
        simulated.add(generateAllelicCount(minorAlleleFractionSequence.get(idx), positions.get(idx), rng, gammaBias, pOutlier));
        idx++;
    }
    return simulated;
}
Also used : AllelicCountCollection(org.broadinstitute.hellbender.tools.exome.alleliccount.AllelicCountCollection) GammaDistribution(org.apache.commons.math3.distribution.GammaDistribution)

Aggregations

TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)8 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)5 GammaDistribution (org.apache.commons.math3.distribution.GammaDistribution)5 MaxEval (org.apache.commons.math3.optim.MaxEval)4 Plot (ij.gui.Plot)3 InitialGuess (org.apache.commons.math3.optim.InitialGuess)3 PointValuePair (org.apache.commons.math3.optim.PointValuePair)3 MultivariateFunctionMappingAdapter (org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter)3 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)3 VisibleForTesting (com.google.common.annotations.VisibleForTesting)2 Sets (com.google.common.collect.Sets)2 Point (java.awt.Point)2 File (java.io.File)2 IOException (java.io.IOException)2 java.util (java.util)2 BiFunction (java.util.function.BiFunction)2 Predicate (java.util.function.Predicate)2 Collectors (java.util.stream.Collectors)2 IntStream (java.util.stream.IntStream)2 Stream (java.util.stream.Stream)2