Search in sources:

Example 6 with RandomGenerator

Use of org.apache.commons.math3.random.RandomGenerator in project GDSC-SMLM by aherbert.

From class FilterTest, method canCompareMultiFilter2.

/**
 * Checks, for 1000 random pairs of filters, that calling {@code weakest} with the argument typed as
 * {@link Filter} gives the same answer as calling it with the argument typed as {@code MultiFilter2}
 * (which may select a specialised overload).
 */
@Test
public void canCompareMultiFilter2() {
    final RandomGenerator rng = new Well19937c(System.currentTimeMillis() + System.identityHashCode(this));
    final MultiFilter2 template = new MultiFilter2(0, 0, 0, 0, 0, 0, 0);
    for (int trial = 0; trial < 1000; trial++) {
        final MultiFilter2 first = (MultiFilter2) template.create(random(template.getNumberOfParameters(), rng));
        final MultiFilter2 second = (MultiFilter2) template.create(random(template.getNumberOfParameters(), rng));
        // The Filter-typed call is the reference result
        final int expected = first.weakest((Filter) second);
        final int observed = first.weakest(second);
        Assert.assertEquals(expected, observed);
    }
}
Also used : Well19937c(org.apache.commons.math3.random.Well19937c) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Test(org.junit.Test)

Example 7 with RandomGenerator

Use of org.apache.commons.math3.random.RandomGenerator in project GDSC-SMLM by aherbert.

From class BinomialFitter, method fitBinomial.

/**
 * Fit the binomial distribution (n,p) to the cumulative histogram. Performs fitting assuming a fixed n value and
 * attempts to optimise p.
 * <p>
 * The probability p is bounded to [0, 1] and fitted with a CMA-ES optimiser starting from
 * {@code min(mean / n, 1)}. Each of the {@code fitRestarts + 1} rounds fits from the initial estimate and then
 * restarts from the best point found so far. The objective is maximised when {@code maximumLikelihood} is true
 * and minimised otherwise; the best candidate is selected using the same direction.
 * <p>
 * When no refit is performed (maximum likelihood fitting, or the trivial zero-truncated n=1 case where p=1) the
 * returned value is the sum-of-squares between the observed and expected values of the model function, to allow
 * comparison across different n. Otherwise, when more than one point is fitted, the solution is refined with an
 * unbounded Levenberg-Marquardt least-squares fit. The refit is only accepted if p remains within [0, 1] and the
 * sum-of-squares improves.
 * <p>
 * Note: if {@code zeroTruncated} is true and {@code histogram[0] > 0} the input histogram is modified in place:
 * the zero bin is set to 0 and the remaining bins are divided by their sum.
 *
 * @param histogram
 *            The input histogram (may be modified, see above)
 * @param mean
 *            The histogram mean (used to estimate p). Calculated if NaN.
 * @param n
 *            The n to evaluate
 * @param zeroTruncated
 *            True if the model should ignore n=0 (zero-truncated binomial)
 * @return The best fit (n, p); the point holds the single fitted parameter p. Returns null if there are no
 *         points to fit, no solution was found, or the fit failed.
 * @throws IllegalArgumentException
 *             If any of the input data values are negative
 * @throws IllegalArgumentException
 *             If fitting a zero-truncated binomial and there are no values above zero
 */
public PointValuePair fitBinomial(double[] histogram, double mean, int n, boolean zeroTruncated) {
    if (Double.isNaN(mean))
        mean = getMean(histogram);
    if (zeroTruncated && histogram[0] > 0) {
        log("Fitting zero-truncated histogram but there are zero values - Renormalising to ignore zero");
        double cumul = 0;
        for (int i = 1; i < histogram.length; i++)
            cumul += histogram[i];
        if (cumul == 0)
            throw new IllegalArgumentException("Fitting zero-truncated histogram but there are no non-zero values");
        // Renormalise in place (this mutates the caller's array)
        histogram[0] = 0;
        for (int i = 1; i < histogram.length; i++)
            histogram[i] /= cumul;
    }
    int nFittedPoints = Math.min(histogram.length, n + 1) - ((zeroTruncated) ? 1 : 0);
    if (nFittedPoints < 1) {
        log("No points to fit (%d): Histogram.length = %d, n = %d, zero-truncated = %b", nFittedPoints,
                histogram.length, n, zeroTruncated);
        return null;
    }
    // The model is only fitting the probability p
    // For a binomial n*p = mean => p = mean/n
    double[] initialSolution = new double[] { FastMath.min(mean / n, 1) };
    // Create the function
    BinomialModelFunction function = new BinomialModelFunction(histogram, n, zeroTruncated);
    double[] lB = new double[1];
    double[] uB = new double[] { 1 };
    SimpleBounds bounds = new SimpleBounds(lB, uB);
    // Fit
    // CMAESOptimizer or BOBYQAOptimizer support bounds
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    int maxIterations = 2000;
    // An alternative stop fitness is Double.NEGATIVE_INFINITY
    double stopFitness = 0;
    boolean isActiveCMA = true;
    int diagonalOnly = 0;
    int checkFeasableCount = 1;
    RandomGenerator random = new Well19937c();
    boolean generateStatistics = false;
    ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
    OptimizationData sigma = new CMAESOptimizer.Sigma(new double[] { (uB[0] - lB[0]) / 3 });
    // NOTE(review): the Matlab default population size is 4 + floor(3 * ln(N)) for N parameters. Only p is
    // fitted (N = 1) but ln(2) is used, giving 6 rather than 4. Left unchanged; confirm this is intended.
    OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(2))));
    try {
        PointValuePair solution = null;
        boolean noRefit = maximumLikelihood;
        if (n == 1 && zeroTruncated) {
            // No need to fit: the only non-zero outcome is n=1 so p=1
            solution = new PointValuePair(new double[] { 1 }, 0);
            noRefit = true;
        } else {
            final GoalType goalType = (maximumLikelihood) ? GoalType.MAXIMIZE : GoalType.MINIMIZE;
            // When maximising, a larger objective value is a better solution
            final boolean maximise = goalType == GoalType.MAXIMIZE;
            // Iteratively fit
            CMAESOptimizer opt = new CMAESOptimizer(maxIterations, stopFitness, isActiveCMA, diagonalOnly,
                    checkFeasableCount, random, generateStatistics, checker);
            for (int iteration = 0; iteration <= fitRestarts; iteration++) {
                try {
                    // Start from the initial solution
                    PointValuePair result = opt.optimize(new InitialGuess(initialSolution),
                            new ObjectiveFunction(function), goalType, bounds, sigma, popSize,
                            new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (solution == null || (maximise ? result.getValue() > solution.getValue()
                            : result.getValue() < solution.getValue())) {
                        solution = result;
                    }
                } catch (TooManyEvaluationsException e) {
                    // Ignore: this restart failed; keep the best solution so far
                } catch (TooManyIterationsException e) {
                    // Ignore: this restart failed; keep the best solution so far
                }
                if (solution == null)
                    continue;
                try {
                    // Also restart from the current optimum
                    PointValuePair result = opt.optimize(new InitialGuess(solution.getPointRef()),
                            new ObjectiveFunction(function), goalType, bounds, sigma, popSize,
                            new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (maximise ? result.getValue() > solution.getValue()
                            : result.getValue() < solution.getValue()) {
                        solution = result;
                    }
                } catch (TooManyEvaluationsException e) {
                    // Ignore: this restart failed; keep the best solution so far
                } catch (TooManyIterationsException e) {
                    // Ignore: this restart failed; keep the best solution so far
                }
            }
            if (solution == null)
                return null;
        }
        if (noRefit) {
            // Although we fit the log-likelihood, return the sum-of-squares to allow
            // comparison across different n
            double p = solution.getPointRef()[0];
            double ss = 0;
            double[] obs = function.p;
            double[] exp = function.getP(p);
            for (int i = 0; i < obs.length; i++)
                ss += (obs[i] - exp[i]) * (obs[i] - exp[i]);
            return new PointValuePair(solution.getPointRef(), ss);
        } else if (nFittedPoints > 1) {
            // We can do a LVM refit if the number of fitted points is more than 1.
            // Improve SS fit with a gradient based LVM optimizer
            LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
            try {
                final BinomialModelFunctionGradient gradientFunction = new BinomialModelFunctionGradient(histogram,
                        n, zeroTruncated);
                //@formatter:off
                LeastSquaresProblem problem = new LeastSquaresBuilder()
                        .maxEvaluations(Integer.MAX_VALUE)
                        .maxIterations(3000)
                        .start(solution.getPointRef())
                        .target(gradientFunction.p)
                        .weight(new DiagonalMatrix(gradientFunction.getWeights()))
                        .model(gradientFunction, new MultivariateMatrixFunction() {
                            @Override
                            public double[][] value(double[] point) throws IllegalArgumentException {
                                return gradientFunction.jacobian(point);
                            }
                        })
                        .build();
                //@formatter:on
                Optimum lvmSolution = optimizer.optimize(problem);
                // Check the pValue is valid since the LVM is not bounded.
                double p = lvmSolution.getPoint().getEntry(0);
                if (p <= 1 && p >= 0) {
                    // The residual dot product is the sum-of-squares when the weights are 1
                    double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
                    if (ss < solution.getValue()) {
                        return new PointValuePair(lvmSolution.getPoint().toArray(), ss);
                    }
                }
            } catch (TooManyIterationsException e) {
                log("Failed to re-fit: Too many iterations: %s", e.getMessage());
            } catch (ConvergenceException e) {
                log("Failed to re-fit: %s", e.getMessage());
            } catch (Exception e) {
                // The refit is a best-effort refinement; report and keep the CMA-ES solution
                log("Failed to re-fit: %s", e.getMessage());
            }
        }
        return solution;
    } catch (Exception e) {
        log("Failed to fit Binomial distribution with N=%d : %s", n, e.getMessage());
    }
    return null;
}
Also used : InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) Well19937c(org.apache.commons.math3.random.Well19937c) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 8 with RandomGenerator

Use of org.apache.commons.math3.random.RandomGenerator in project GDSC-SMLM by aherbert.

From class CMOSAnalysis, method simulate.

/**
 * Simulate per-pixel CMOS camera data and save it to the configured directory.
 * <p>
 * A per-pixel model is created first:
 * <ul>
 * <li>offset is sampled from a Poisson distribution with mean {@code offset};</li>
 * <li>variance is sampled from an exponential distribution with mean {@code variance};</li>
 * <li>gain is sampled as {@code gain + N(0,1) * gainSD}.</li>
 * </ul>
 * The model is saved as a 3-slice stack ("Offset", "Variance", "Gain") to {@code perPixelSimulation.tif}.
 * <p>
 * For each photon level in {0, 20, 50, 100, 200}, SimulationWorker tasks are then run on a fixed thread pool.
 * Each task is given the path of a sub-directory named {@code photonNNN} and a range of frames.
 * <p>
 * The ImageJ status/progress display is suppressed while saving. It is always restored, and the thread pool is
 * always shut down, even if an error occurs.
 */
private void simulate() {
    // Create the offset, variance and gain for each pixel
    int n = size * size;
    float[] pixelOffset = new float[n];
    float[] pixelVariance = new float[n];
    float[] pixelGain = new float[n];
    IJ.showStatus("Creating random per-pixel readout");
    long start = System.currentTimeMillis();
    RandomGenerator rg = new Well19937c();
    PoissonDistribution pd = new PoissonDistribution(rg, offset, PoissonDistribution.DEFAULT_EPSILON,
            PoissonDistribution.DEFAULT_MAX_ITERATIONS);
    ExponentialDistribution ed = new ExponentialDistribution(rg, variance,
            ExponentialDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
    totalProgress = n;
    stepProgress = Utils.getProgressInterval(totalProgress);
    for (int i = 0; i < n; i++) {
        // Update the progress bar at the configured interval (guarding against a zero interval)
        if (stepProgress > 0 && i % stepProgress == 0)
            IJ.showProgress(i, n);
        // Q. Should these be clipped to a sensible range?
        pixelOffset[i] = (float) pd.sample();
        pixelVariance[i] = (float) ed.sample();
        pixelGain[i] = (float) (gain + rg.nextGaussian() * gainSD);
    }
    IJ.showProgress(1);

    // Avoid all the file saves from updating the progress bar and status line
    Utils.setShowStatus(false);
    Utils.setShowProgress(false);
    ExecutorService executor = null;
    try {
        JLabel statusLine = Utils.getStatusLine();
        progressBar = Utils.getProgressBar();

        // Save to the directory as a stack
        ImageStack simulationStack = new ImageStack(size, size);
        simulationStack.addSlice("Offset", pixelOffset);
        simulationStack.addSlice("Variance", pixelVariance);
        simulationStack.addSlice("Gain", pixelGain);
        simulationImp = new ImagePlus("PerPixel", simulationStack);
        // Only the info property is saved to the TIFF file
        simulationImp.setProperty("Info", String.format("Offset=%s; Variance=%s; Gain=%s +/- %s",
                Utils.rounded(offset), Utils.rounded(variance), Utils.rounded(gain), Utils.rounded(gainSD)));
        IJ.save(simulationImp, new File(directory, "perPixelSimulation.tif").getPath());

        // Create thread pool and workers
        executor = Executors.newFixedThreadPool(getThreads());
        TurboList<Future<?>> futures = new TurboList<Future<?>>(nThreads);

        // Simulate the zero exposure input.
        // Simulate 20 - 200 photon images.
        int[] photons = new int[] { 0, 20, 50, 100, 200 };
        totalProgress = photons.length * frames;
        stepProgress = Utils.getProgressInterval(totalProgress);
        progress = 0;
        progressBar.show(0);

        // For saving stacks
        int blockSize = 10;
        int nPerThread = (int) Math.ceil((double) frames / nThreads);
        // Convert to fit the block size
        nPerThread = (int) Math.ceil((double) nPerThread / blockSize) * blockSize;
        long seed = start;
        for (int p : photons) {
            statusLine.setText("Simulating " + Utils.pleural(p, "photon"));
            // Create the directory
            File out = new File(directory, String.format("photon%03d", p));
            if (!out.exists() && !out.mkdir()) {
                // The workers cannot save without the output directory
                Utils.log("Failed to create output directory: " + out.getPath());
                continue;
            }
            for (int from = 0; from < frames;) {
                int to = Math.min(from + nPerThread, frames);
                futures.add(executor.submit(
                        new SimulationWorker(seed++, out.getPath(), simulationStack, from, to, blockSize, p)));
                from = to;
            }
            // Wait for all to finish
            for (int t = futures.size(); t-- > 0;) {
                try {
                    // The future .get() method will block until completed
                    futures.get(t).get();
                } catch (Exception e) {
                    // This should not happen.
                    e.printStackTrace();
                }
            }
            futures.clear();
        }
    } finally {
        // Always release the worker threads and restore the ImageJ status/progress display
        if (executor != null)
            executor.shutdown();
        Utils.setShowStatus(true);
        Utils.setShowProgress(true);
    }
    IJ.showProgress(1);
    Utils.log("Simulation time = " + Utils.timeToString(System.currentTimeMillis() - start));
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) TurboList(gdsc.core.utils.TurboList) ImageStack(ij.ImageStack) ExponentialDistribution(org.apache.commons.math3.distribution.ExponentialDistribution) JLabel(javax.swing.JLabel) Well19937c(org.apache.commons.math3.random.Well19937c) ImagePlus(ij.ImagePlus) PseudoRandomGenerator(gdsc.core.utils.PseudoRandomGenerator) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) ExecutorService(java.util.concurrent.ExecutorService) Future(java.util.concurrent.Future) File(java.io.File)

Example 9 with RandomGenerator

Use of org.apache.commons.math3.random.RandomGenerator in project GDSC-SMLM by aherbert.

From class FisherInformationMatrixTest, method createFisherInformationMatrix.

/**
 * Build a reproducible test Fisher information matrix.
 * <p>
 * The matrix is computed from an elliptical 2D Gaussian function on a 10x10 grid. The function has enough peaks
 * that it has at least {@code n} parameters, all set to random values from a fixed seed. The matrix is truncated
 * to its top-left {@code n x n} block. If {@code k > 0}, {@code k} randomly chosen rows and their matching
 * columns are set to zero.
 *
 * @param n
 *            the size of the returned matrix
 * @param k
 *            the number of rows/columns to zero
 * @return the Fisher information matrix
 */
private FisherInformationMatrix createFisherInformationMatrix(int n, int k) {
    final int maxx = 10;
    final int size = maxx * maxx;
    // Fixed seed so every call produces the same matrix
    final RandomGenerator randomGenerator = new Well19937c(30051977);
    final RandomDataGenerator rdg = new RandomDataGenerator(randomGenerator);

    // Use a real Gaussian function here to compute the Fisher information.
    // The matrix may be sensitive to the type of equation used.
    // Each peak contributes 6 parameters on top of the shared background.
    int npeaks = 1;
    while (1 + npeaks * 6 < n) {
        npeaks++;
    }
    final Gaussian2DFunction f = GaussianFunctionFactory.create2D(npeaks, maxx, maxx,
            GaussianFunctionFactory.FIT_ELLIPTICAL, null);

    final double[] params = new double[1 + npeaks * 6];
    params[Gaussian2DFunction.BACKGROUND] = rdg.nextUniform(1, 5);
    for (int peak = 0; peak < npeaks; peak++) {
        final int base = peak * 6;
        final double lo = 2 + peak * 2;
        final double hi = 4 + peak * 2;
        params[base + Gaussian2DFunction.SIGNAL] = rdg.nextUniform(100, 300);
        params[base + Gaussian2DFunction.SHAPE] = rdg.nextUniform(-Math.PI, Math.PI);
        // Non-overlapping peaks otherwise the CRLB are poor
        params[base + Gaussian2DFunction.X_POSITION] = rdg.nextUniform(lo, hi);
        params[base + Gaussian2DFunction.Y_POSITION] = rdg.nextUniform(lo, hi);
        params[base + Gaussian2DFunction.X_SD] = rdg.nextUniform(1.5, 2);
        params[base + Gaussian2DFunction.Y_SD] = rdg.nextUniform(1.5, 2);
    }
    f.initialise(params);

    final GradientCalculator calculator = GradientCalculatorFactory.newCalculator(params.length);
    final double[][] full = calculator.fisherInformationMatrix(size, params, f);

    // Reduce to the desired size
    final double[][] fim = Arrays.copyOf(full, n);
    for (int row = 0; row < n; row++) {
        fim[row] = Arrays.copyOf(fim[row], n);
    }

    // Zero selected rows and the matching columns
    if (k > 0) {
        // Same underlying generator as before, so the permutation continues the random sequence
        for (int target : rdg.nextPermutation(n, k)) {
            for (int j = 0; j < n; j++) {
                fim[j][target] = 0;
                fim[target][j] = 0;
            }
        }
    }

    // Create matrix
    return new FisherInformationMatrix(fim, 1e-3);
}
Also used : RandomDataGenerator(org.apache.commons.math3.random.RandomDataGenerator) Gaussian2DFunction(gdsc.smlm.function.gaussian.Gaussian2DFunction) GradientCalculator(gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator) Well19937c(org.apache.commons.math3.random.Well19937c) RandomGenerator(org.apache.commons.math3.random.RandomGenerator)

Example 10 with RandomGenerator

Use of org.apache.commons.math3.random.RandomGenerator in project GDSC-SMLM by aherbert.

From class MaximumLikelihoodFitter, method computeFit.

/*
	 * (non-Javadoc)
	 * 
	 * @see gdsc.smlm.fitting.nonlinear.BaseFunctionSolver#computeFit(double[], double[], double[], double[])
	 */
public FitStatus computeFit(double[] y, double[] y_fit, double[] a, double[] a_dev) {
    final int n = y.length;
    LikelihoodWrapper maximumLikelihoodFunction = createLikelihoodWrapper((NonLinearFunction) f, n, y, a);
    @SuppressWarnings("rawtypes") BaseOptimizer baseOptimiser = null;
    try {
        double[] startPoint = getInitialSolution(a);
        PointValuePair optimum = null;
        if (searchMethod == SearchMethod.POWELL || searchMethod == SearchMethod.POWELL_BOUNDED || searchMethod == SearchMethod.POWELL_ADAPTER) {
            // Non-differentiable version using Powell Optimiser
            // This is as per the method in Numerical Recipes 10.5 (Direction Set (Powell's) method)
            // I could extend the optimiser and implement bounds on the directions moved. However the mapping
            // adapter seems to work OK.
            final boolean basisConvergence = false;
            // Perhaps these thresholds should be tighter?
            // The default is to use the sqrt() of the overall tolerance
            //final double lineRel = FastMath.sqrt(relativeThreshold);
            //final double lineAbs = FastMath.sqrt(absoluteThreshold);
            //final double lineRel = relativeThreshold * 1e2;
            //final double lineAbs = absoluteThreshold * 1e2;
            // Since we are fitting only a small number of parameters then just use the same tolerance 
            // for each search direction
            final double lineRel = relativeThreshold;
            final double lineAbs = absoluteThreshold;
            CustomPowellOptimizer o = new CustomPowellOptimizer(relativeThreshold, absoluteThreshold, lineRel, lineAbs, null, basisConvergence);
            baseOptimiser = o;
            OptimizationData maxIterationData = null;
            if (getMaxIterations() > 0)
                maxIterationData = new MaxIter(getMaxIterations());
            if (searchMethod == SearchMethod.POWELL_ADAPTER) {
                // Try using the mapping adapter for a bounded Powell search
                MultivariateFunctionMappingAdapter adapter = new MultivariateFunctionMappingAdapter(new MultivariateLikelihood(maximumLikelihoodFunction), lower, upper);
                optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(adapter), GoalType.MINIMIZE, new InitialGuess(adapter.boundedToUnbounded(startPoint)));
                double[] solution = adapter.unboundedToBounded(optimum.getPointRef());
                optimum = new PointValuePair(solution, optimum.getValue());
            } else {
                if (powellFunction == null) {
                    // Python code by using the sqrt of the number of photons and background.
                    if (mapGaussian) {
                        Gaussian2DFunction gf = (Gaussian2DFunction) f;
                        // Re-map signal and background using the sqrt
                        int[] indices = gf.gradientIndices();
                        int[] map = new int[indices.length];
                        int count = 0;
                        // Background is always first
                        if (indices[0] == Gaussian2DFunction.BACKGROUND) {
                            map[count++] = 0;
                        }
                        // Look for the Signal in multiple peak 2D Gaussians
                        for (int i = 1; i < indices.length; i++) if (indices[i] % 6 == Gaussian2DFunction.SIGNAL) {
                            map[count++] = i;
                        }
                        if (count > 0) {
                            powellFunction = new MappedMultivariateLikelihood(maximumLikelihoodFunction, Arrays.copyOf(map, count));
                        }
                    }
                    if (powellFunction == null) {
                        powellFunction = new MultivariateLikelihood(maximumLikelihoodFunction);
                    }
                }
                // Update the maximum likelihood function in the Powell function wrapper
                powellFunction.fun = maximumLikelihoodFunction;
                OptimizationData positionChecker = null;
                // new org.apache.commons.math3.optim.PositionChecker(relativeThreshold, absoluteThreshold);
                SimpleBounds simpleBounds = null;
                if (powellFunction.isMapped()) {
                    MappedMultivariateLikelihood adapter = (MappedMultivariateLikelihood) powellFunction;
                    if (searchMethod == SearchMethod.POWELL_BOUNDED)
                        simpleBounds = new SimpleBounds(adapter.map(lower), adapter.map(upper));
                    optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(adapter.map(startPoint)), positionChecker, simpleBounds);
                    double[] solution = adapter.unmap(optimum.getPointRef());
                    optimum = new PointValuePair(solution, optimum.getValue());
                } else {
                    if (searchMethod == SearchMethod.POWELL_BOUNDED)
                        simpleBounds = new SimpleBounds(lower, upper);
                    optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(startPoint), positionChecker, simpleBounds);
                }
            }
        } else if (searchMethod == SearchMethod.BOBYQA) {
            // Differentiable approximation using Powell's BOBYQA algorithm.
            // This is slower than the Powell optimiser and requires a high number of evaluations.
            int numberOfInterpolationPoints = this.getNumberOfFittedParameters() + 2;
            BOBYQAOptimizer o = new BOBYQAOptimizer(numberOfInterpolationPoints);
            baseOptimiser = o;
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lower, upper));
        } else if (searchMethod == SearchMethod.CMAES) {
            // TODO - Understand why the CMAES optimiser does not fit very well on test data. It appears 
            // to converge too early and the likelihood scores are not as low as the other optimisers.
            // CMAESOptimiser based on Matlab code:
            // https://www.lri.fr/~hansen/cmaes.m
            // Take the defaults from the Matlab documentation
            //Double.NEGATIVE_INFINITY;
            double stopFitness = 0;
            boolean isActiveCMA = true;
            int diagonalOnly = 0;
            int checkFeasableCount = 1;
            RandomGenerator random = new Well19937c();
            boolean generateStatistics = false;
            // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
            double[] sigma = new double[lower.length];
            for (int i = 0; i < sigma.length; i++) sigma[i] = (upper[i] - lower[i]) / 3;
            int popSize = (int) (4 + Math.floor(3 * Math.log(sigma.length)));
            // The CMAES optimiser is random and restarting can overcome problems with quick convergence.
            // The Apache commons documentations states that convergence should occur between 30N and 300N^2
            // function evaluations
            final int n30 = FastMath.min(sigma.length * sigma.length * 30, getMaxEvaluations() / 2);
            evaluations = 0;
            OptimizationData[] data = new OptimizationData[] { new InitialGuess(startPoint), new CMAESOptimizer.PopulationSize(popSize), new MaxEval(getMaxEvaluations()), new CMAESOptimizer.Sigma(sigma), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new SimpleBounds(lower, upper) };
            // Iterate to prevent early convergence
            int repeat = 0;
            while (evaluations < n30) {
                if (repeat++ > 1) {
                    // Update the start point and population size
                    data[0] = new InitialGuess(optimum.getPointRef());
                    popSize *= 2;
                    data[1] = new CMAESOptimizer.PopulationSize(popSize);
                }
                CMAESOptimizer o = new CMAESOptimizer(getMaxIterations(), stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
                baseOptimiser = o;
                PointValuePair result = o.optimize(data);
                iterations += o.getIterations();
                evaluations += o.getEvaluations();
                //		o.getEvaluations(), totalEvaluations);
                if (optimum == null || result.getValue() < optimum.getValue()) {
                    optimum = result;
                }
            }
            // Prevent incrementing the iterations again
            baseOptimiser = null;
        } else if (searchMethod == SearchMethod.BFGS) {
            // BFGS can use an approximate line search minimisation where as Powell and conjugate gradient
            // methods require a more accurate line minimisation. The BFGS search does not do a full 
            // minimisation but takes appropriate steps in the direction of the current gradient.
            // Do not use the convergence checker on the value of the function. Use the convergence on the 
            // point coordinate and gradient
            //BFGSOptimizer o = new BFGSOptimizer(new SimpleValueChecker(rel, abs));
            BFGSOptimizer o = new BFGSOptimizer();
            baseOptimiser = o;
            // Configure maximum step length for each dimension using the bounds
            double[] stepLength = new double[lower.length];
            for (int i = 0; i < stepLength.length; i++) {
                stepLength[i] = (upper[i] - lower[i]) * 0.3333333;
                if (stepLength[i] <= 0)
                    stepLength[i] = Double.POSITIVE_INFINITY;
            }
            // The GoalType is always minimise so no need to pass this in
            OptimizationData positionChecker = null;
            //new org.apache.commons.math3.optim.PositionChecker(relativeThreshold, absoluteThreshold);
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BFGSOptimizer.GradientTolerance(relativeThreshold), positionChecker, new BFGSOptimizer.StepLength(stepLength));
        } else {
            // The line search algorithm often fails. This is due to searching into a region where the 
            // function evaluates to a negative so has been clipped. This means the upper bound of the line
            // cannot be found.
            // Note that running it on an easy problem (200 photons with fixed fitting (no background)) the algorithm
            // does sometimes produces results better than the Powell algorithm but it is slower.
            BoundedNonLinearConjugateGradientOptimizer o = new BoundedNonLinearConjugateGradientOptimizer((searchMethod == SearchMethod.CONJUGATE_GRADIENT_FR) ? Formula.FLETCHER_REEVES : Formula.POLAK_RIBIERE, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
            baseOptimiser = o;
            // Note: The gradients may become unstable at the edge of the bounds. Or they will not change 
            // direction if the true solution is on the bounds since the gradient will always continue 
            // towards the bounds. This is key to the conjugate gradient method. It searches along a vector 
            // until the direction of the gradient is in the opposite direction (using dot products, i.e. 
            // cosine of angle between them)
            // NR 10.7 states there is no advantage of the variable metric DFP or BFGS methods over
            // conjugate gradient methods. So I will try these first.
            // Try this:
            // Adapt the conjugate gradient optimiser to use the gradient to pick the search direction
            // and then for the line minimisation. However if the function is out of bounds then clip the 
            // variables at the bounds and continue. 
            // If the current point is at the bounds and the gradient is to continue out of bounds then 
            // clip the gradient too.
            // Or: just use the gradient for the search direction then use the line minimisation/rest
            // as per the Powell optimiser. The bounds should limit the search.
            // I tried a Bounded conjugate gradient optimiser with clipped variables:
            // This sometimes works. However when the variables go a long way out of the expected range the gradients
            // can have vastly different magnitudes. This results in the algorithm stalling since the gradients
            // can be close to zero and the some of the parameters are no longer adjusted.
            // Perhaps this can be looked for and the algorithm then gives up and resorts to a Powell optimiser from 
            // the current point.
            // Changed the bracketing step to very small (default is 1, changed to 0.001). This improves the 
            // performance. The gradient direction is very sensitive to small changes in the coordinates so a 
            // tighter bracketing of the line search helps.
            // Tried using a non-gradient method for the line search copied from the Powell optimiser:
            // This also works when the bracketing step is small but the number of iterations is higher.
            // 24.10.2014: I have tried to get conjugate gradient to work but the gradient function 
            // must not behave suitably for the optimiser. In the current state both methods of using a 
            // Bounded Conjugate Gradient Optimiser perform poorly relative to other optimisers:
            // Simulated : n=1000, signal=200, x=0.53, y=0.47
            // LVM : n=1000, signal=171, x=0.537, y=0.471 (1.003s)
            // Powell : n=1000, signal=187, x=0.537, y=0.48 (1.238s)
            // Gradient based PR (constrained): n=858, signal=161, x=0.533, y=0.474 (2.54s)
            // Gradient based PR (bounded): n=948, signal=161, x=0.533, y=0.473 (2.67s)
            // Non-gradient based : n=1000, signal=151.47, x=0.535, y=0.474 (1.626s)
            // The conjugate optimisers are slower, under predict the signal by the most and in the case of 
            // the gradient based optimiser, fail to converge on some problems. This is worse when constrained
            // fitting is used and not tightly bounded fitting.
            // I will leave the code in as an option but would not recommend using it. I may remove it in the 
            // future.
            // Note: It is strange that the non-gradient based line minimisation is more successful.
            // It may be that the gradient function is not accurate (due to round off error) or that it is
            // simply wrong when far from the optimum. My JUnit tests only evaluate the function within the 
            // expected range of the answer.
            // Note the default step size on the Powell optimiser is 1 but the initial directions are unit vectors.
            // So our bracketing step should be a minimum of 1 / average length of the first gradient vector to prevent
            // the first step being too large when bracketing.
            final double[] gradient = new double[startPoint.length];
            maximumLikelihoodFunction.likelihood(startPoint, gradient);
            double l = 0;
            for (double d : gradient) l += d * d;
            final double bracketingStep = FastMath.min(0.001, ((l > 1) ? 1.0 / l : 1));
            //System.out.printf("Bracketing step = %f (length=%f)\n", bracketingStep, l);
            o.setUseGradientLineSearch(gradientLineMinimisation);
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BoundedNonLinearConjugateGradientOptimizer.BracketingStep(bracketingStep));
        //maximumLikelihoodFunction.value(solution, gradient);
        //System.out.printf("Iter = %d, %g @ %s : %s\n", iterations, ll, Arrays.toString(solution),
        //		Arrays.toString(gradient));
        }
        final double[] solution = optimum.getPointRef();
        setSolution(a, solution);
        if (a_dev != null) {
            // Assume the Maximum Likelihood estimator returns the optimum fit (achieves the Cramer Roa
            // lower bounds) and so the covariance can be obtained from the Fisher Information Matrix.
            FisherInformationMatrix m = new FisherInformationMatrix(maximumLikelihoodFunction.fisherInformation(a));
            setDeviations(a_dev, m.crlb(true));
        }
        // Reverse negative log likelihood for maximum likelihood score
        value = -optimum.getValue();
    } catch (TooManyIterationsException e) {
        //e.printStackTrace();
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (TooManyEvaluationsException e) {
        //e.printStackTrace();
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (ConvergenceException e) {
        //System.out.printf("Singular non linear model = %s\n", e.getMessage());
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (BFGSOptimizer.LineSearchRoundoffException e) {
        //e.printStackTrace();
        return FitStatus.FAILED_TO_CONVERGE;
    } catch (Exception e) {
        //System.out.printf("Unknown error = %s\n", e.getMessage());
        e.printStackTrace();
        return FitStatus.UNKNOWN;
    } finally {
        if (baseOptimiser != null) {
            iterations += baseOptimiser.getIterations();
            evaluations += baseOptimiser.getEvaluations();
        }
    }
    // Check this as likelihood functions can go wrong
    if (Double.isInfinite(value) || Double.isNaN(value))
        return FitStatus.INVALID_LIKELIHOOD;
    return FitStatus.OK;
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) BOBYQAOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) Well19937c(org.apache.commons.math3.random.Well19937c) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) BFGSOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.gradient.BFGSOptimizer) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Gaussian2DFunction(gdsc.smlm.function.gaussian.Gaussian2DFunction) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) BoundedNonLinearConjugateGradientOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) BaseOptimizer(org.apache.commons.math3.optim.BaseOptimizer) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) FisherInformationMatrix(gdsc.smlm.fitting.FisherInformationMatrix) PoissonGammaGaussianLikelihoodWrapper(gdsc.smlm.function.PoissonGammaGaussianLikelihoodWrapper) PoissonGaussianLikelihoodWrapper(gdsc.smlm.function.PoissonGaussianLikelihoodWrapper) PoissonLikelihoodWrapper(gdsc.smlm.function.PoissonLikelihoodWrapper) LikelihoodWrapper(gdsc.smlm.function.LikelihoodWrapper) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) MultivariateFunctionMappingAdapter(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter) OptimizationData(org.apache.commons.math3.optim.OptimizationData) CustomPowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer) MaxIter(org.apache.commons.math3.optim.MaxIter)

Aggregations

RandomGenerator (org.apache.commons.math3.random.RandomGenerator) 70
Well19937c (org.apache.commons.math3.random.Well19937c) 25
Random (java.util.Random) 20
Test (org.testng.annotations.Test) 18
RandomGeneratorFactory (org.apache.commons.math3.random.RandomGeneratorFactory) 16
Assert (org.testng.Assert) 16
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval) 14
Collectors (java.util.stream.Collectors) 12
IntStream (java.util.stream.IntStream) 12
Arrays (java.util.Arrays) 10
List (java.util.List) 10
Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix) 10
Test (org.junit.Test) 10
ArrayList (java.util.ArrayList) 9
NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution) 8
ModeledSegment (org.broadinstitute.hellbender.tools.exome.ModeledSegment) 8
AllelicCountCollection (org.broadinstitute.hellbender.tools.exome.alleliccount.AllelicCountCollection) 8
java.util (java.util) 6
GammaDistribution (org.apache.commons.math3.distribution.GammaDistribution) 6
BinomialDistribution (org.apache.commons.math3.distribution.BinomialDistribution) 5