Search in sources :

Example 1 with Weight

use of org.apache.commons.math3.optim.nonlinear.vector.Weight in project GDSC-SMLM by aherbert.

Class PSFCreator, method run.

/**
	 * Creates a PSF image from the isolated spots in the current image stack.
	 * <p>
	 * Phase 1 (spot location): each spot returned by {@code getSpots()} is fitted through the stack. The fitted
	 * amplitude is LOESS-smoothed to find the frames above {@code amplitudeFraction} of the smoothed maximum. Within
	 * that range the X/Y offsets and width are smoothed, and the centre is taken at the slice of minimum smoothed
	 * width near the amplitude peak. Spots may be rejected by {@code ignoreSpot(...)}.
	 * <p>
	 * Phase 2 (PSF construction): each accepted spot is cropped from every slice, background subtracted/windowed and
	 * added into a magnified PSF stack. The stack is normalised, displayed and fitted, a centre-of-mass drift plot is
	 * shown, and a {@code PSFSettings} XML "Info" property is attached to the PSF image.
	 * 
	 * @param ip
	 *            not referenced by this method; the data is taken from {@code getImageStack()}
	 * @see ij.plugin.filter.PlugInFilter#run(ij.process.ImageProcessor)
	 */
public void run(ImageProcessor ip) {
    loadConfiguration();
    BasePoint[] spots = getSpots();
    if (spots.length == 0) {
        IJ.error(TITLE, "No spots without neighbours within " + (boxRadius * 2) + "px");
        return;
    }
    ImageStack stack = getImageStack();
    final int width = imp.getWidth();
    final int height = imp.getHeight();
    // Remembered so the image can be restored after interactive mode changes the slice
    final int currentSlice = imp.getSlice();
    // Adjust settings for a single maxima
    config.setIncludeNeighbours(false);
    fitConfig.setDuplicateDistance(0);
    // Each accepted centre is stored as { x, y, z (slice), sd, spot number }
    ArrayList<double[]> centres = new ArrayList<double[]>(spots.length);
    int iterations = 1;
    LoessInterpolator loess = new LoessInterpolator(smoothing, iterations);
    // TODO - The fitting routine may not produce many points. In this instance the LOESS interpolator
    // fails to smooth the data very well. A higher bandwidth helps this but perhaps 
    // try a different smoothing method.
    // For each spot
    Utils.log(TITLE + ": " + imp.getTitle());
    Utils.log("Finding spot locations...");
    Utils.log("  %d spot%s without neighbours within %dpx", spots.length, ((spots.length == 1) ? "" : "s"), (boxRadius * 2));
    // Per-spot values at the chosen centre: SD and amplitude (combined later in getAverage) and the
    // number of frames in the smoothed range (its mean is passed to fitPSF)
    StoredDataStatistics averageSd = new StoredDataStatistics();
    StoredDataStatistics averageA = new StoredDataStatistics();
    Statistics averageRange = new Statistics();
    // Collects every fit result from all spots and registers it as a named results set
    MemoryPeakResults allResults = new MemoryPeakResults();
    allResults.setName(TITLE);
    allResults.setBounds(new Rectangle(0, 0, width, height));
    MemoryPeakResults.addResults(allResults);
    // Phase 1: locate the centre of each spot (spot numbers are 1-based for logging)
    for (int n = 1; n <= spots.length; n++) {
        BasePoint spot = spots[n - 1];
        final int x = (int) spot.getX();
        final int y = (int) spot.getY();
        MemoryPeakResults results = fitSpot(stack, width, height, x, y);
        allResults.addAllf(results.getResults());
        // Too few fitted frames to smooth
        if (results.size() < 5) {
            Utils.log("  Spot %d: Not enough fit results %d", n, results.size());
            continue;
        }
        // Get the results for the spot centre and width
        double[] z = new double[results.size()];
        double[] xCoord = new double[z.length];
        double[] yCoord = new double[z.length];
        double[] sd = new double[z.length];
        double[] a = new double[z.length];
        int i = 0;
        for (PeakResult peak : results.getResults()) {
            // The frame number is used as the z coordinate; x/y are offsets from the integer spot position
            z[i] = peak.getFrame();
            xCoord[i] = peak.getXPosition() - x;
            yCoord[i] = peak.getYPosition() - y;
            sd[i] = FastMath.max(peak.getXSD(), peak.getYSD());
            a[i] = peak.getAmplitude();
            i++;
        }
        // Smooth the amplitude plot
        double[] smoothA = loess.smooth(z, a);
        // Find the maximum amplitude
        int maximumIndex = findMaximumIndex(smoothA);
        // Find the range at a fraction of the max. This is smoothed to find the X/Y centre
        int start = 0, stop = smoothA.length - 1;
        double limit = smoothA[maximumIndex] * amplitudeFraction;
        // First index above the limit
        for (int j = 0; j < smoothA.length; j++) {
            if (smoothA[j] > limit) {
                start = j;
                break;
            }
        }
        // Last index above the limit (scanning backwards)
        for (int j = smoothA.length; j-- > 0; ) {
            if (smoothA[j] > limit) {
                stop = j;
                break;
            }
        }
        averageRange.add(stop - start + 1);
        // Extract xy centre coords and smooth
        double[] smoothX = new double[stop - start + 1];
        double[] smoothY = new double[smoothX.length];
        double[] smoothSd = new double[smoothX.length];
        double[] newZ = new double[smoothX.length];
        // Copy the [start, stop] sub-range; arrays below are indexed relative to 'start'
        for (int j = start, k = 0; j <= stop; j++, k++) {
            smoothX[k] = xCoord[j];
            smoothY[k] = yCoord[j];
            smoothSd[k] = sd[j];
            newZ[k] = z[j];
        }
        smoothX = loess.smooth(newZ, smoothX);
        smoothY = loess.smooth(newZ, smoothY);
        smoothSd = loess.smooth(newZ, smoothSd);
        // Since the amplitude is not very consistent move from this peak to the 
        // lowest width which is the in-focus spot.
        // The amplitude peak index is re-expressed relative to the sub-range; the exact search
        // behaviour is defined by findMinimumIndex (not visible here).
        maximumIndex = findMinimumIndex(smoothSd, maximumIndex - start);
        // Find the centre at the selected (minimum width) index; smoothA is indexed from 0, so add 'start'
        double cx = smoothX[maximumIndex] + x;
        double cy = smoothY[maximumIndex] + y;
        int cz = (int) newZ[maximumIndex];
        double csd = smoothSd[maximumIndex];
        double ca = smoothA[maximumIndex + start];
        // The average should weight the SD using the signal for each spot
        // NOTE(review): these are recorded before ignoreSpot() and before any interactive move below,
        // so rejected spots (and pre-move values) still contribute to the average SD used by normalise().
        // Confirm this is intended.
        averageSd.add(smoothSd[maximumIndex]);
        averageA.add(ca);
        if (ignoreSpot(n, z, a, smoothA, xCoord, yCoord, sd, newZ, smoothX, smoothY, smoothSd, cx, cy, cz, csd)) {
            Utils.log("  Spot %d was ignored", n);
            continue;
        }
        // Store result - it may have been moved interactively
        // this.slice presumably holds the slice chosen in ignoreSpot(); shift the index by the offset.
        // NOTE(review): no bounds check - an offset moving outside [0, newZ.length) would throw here.
        maximumIndex += this.slice - cz;
        cz = (int) newZ[maximumIndex];
        csd = smoothSd[maximumIndex];
        ca = smoothA[maximumIndex + start];
        Utils.log("  Spot %d => x=%.2f, y=%.2f, z=%d, sd=%.2f, A=%.2f\n", n, cx, cy, cz, csd, ca);
        centres.add(new double[] { cx, cy, cz, csd, n });
    }
    if (interactiveMode) {
        imp.setSlice(currentSlice);
        imp.setOverlay(null);
        // Hide the amplitude and spot plots
        Utils.hide(TITLE_AMPLITUDE);
        Utils.hide(TITLE_PSF_PARAMETERS);
    }
    if (centres.isEmpty()) {
        String msg = "No suitable spots could be identified centres";
        Utils.log(msg);
        IJ.error(TITLE, msg);
        return;
    }
    // Find the limits of the z-centre
    int minz = (int) centres.get(0)[2];
    int maxz = minz;
    for (double[] centre : centres) {
        if (minz > centre[2])
            minz = (int) centre[2];
        else if (maxz < centre[2])
            maxz = (int) centre[2];
    }
    IJ.showStatus("Creating PSF image");
    // Create a stack that can hold all the data.
    ImageStack psf = createStack(stack, minz, maxz, magnification);
    // Phase 2: add each spot to the PSF. 'stats' collects the background of each spot actually added.
    Statistics stats = new Statistics();
    boolean ok = true;
    for (int i = 0; ok && i < centres.size(); i++) {
        double progress = (double) i / centres.size();
        // Progress step per slice for addToPSF (loop-invariant)
        final double increment = 1.0 / (stack.getSize() * centres.size());
        IJ.showProgress(progress);
        double[] centre = centres.get(i);
        // Extract the spot
        float[][] spot = new float[stack.getSize()][];
        Rectangle regionBounds = null;
        for (int slice = 1; slice <= stack.getSize(); slice++) {
            ImageExtractor ie = new ImageExtractor((float[]) stack.getPixels(slice), width, height);
            // The same crop region (computed from the first slice) is used for every slice
            if (regionBounds == null)
                regionBounds = ie.getBoxRegionBounds((int) centre[0], (int) centre[1], boxRadius);
            spot[slice - 1] = ie.crop(regionBounds);
        }
        // Spot number stored in the centre record
        int n = (int) centre[4];
        final float b = getBackground(n, spot);
        // A false return rejects the spot
        if (!subtractBackgroundAndWindow(spot, b, regionBounds.width, regionBounds.height, centre, loess)) {
            Utils.log("  Spot %d was ignored", n);
            continue;
        }
        stats.add(b);
        // Adjust the centre using the crop
        centre[0] -= regionBounds.x;
        centre[1] -= regionBounds.y;
        // This takes a long time so this should track progress
        // A false return stops the loop (checked in the loop condition) and aborts the PSF
        ok = addToPSF(maxz, magnification, psf, centre, spot, regionBounds, progress, increment, centreEachSlice);
    }
    if (interactiveMode) {
        Utils.hide(TITLE_INTENSITY);
    }
    IJ.showProgress(1);
    if (threadPool != null) {
        threadPool.shutdownNow();
        threadPool = null;
    }
    // Nothing to show if addToPSF failed or no spot was added
    if (!ok || stats.getN() == 0)
        return;
    final double avSd = getAverage(averageSd, averageA, 2);
    Utils.log("  Average background = %.2f, Av. SD = %s px", stats.getMean(), Utils.rounded(avSd, 4));
    // The SD is scaled to the magnified PSF pixel size
    normalise(psf, maxz, avSd * magnification, false);
    IJ.showProgress(1);
    psfImp = Utils.display("PSF", psf);
    psfImp.setSlice(maxz);
    psfImp.resetDisplayRange();
    psfImp.updateAndDraw();
    // Fitted centres per slice ([0] = x, [1] = y); NaN marks slices without a fitted centre
    double[][] fitCom = new double[2][psf.getSize()];
    Arrays.fill(fitCom[0], Double.NaN);
    Arrays.fill(fitCom[1], Double.NaN);
    double fittedSd = fitPSF(psf, loess, maxz, averageRange.getMean(), fitCom);
    // Compute the drift in the PSF:
    // - Use fitted centre if available; otherwise find CoM for each frame
    // - express relative to the average centre
    double[][] com = calculateCentreOfMass(psf, fitCom, nmPerPixel / magnification);
    // Slice numbers for the plot x-axis (assumes Utils.newArray(n, start, increment) => 1..N)
    double[] slice = Utils.newArray(psf.getSize(), 1, 1.0);
    String title = TITLE + " CoM Drift";
    Plot2 plot = new Plot2(title, "Slice", "Drift (nm)");
    plot.addLabel(0, 0, "Red = X; Blue = Y");
    //double[] limitsX = Maths.limits(com[0]);
    //double[] limitsY = Maths.limits(com[1]);
    double[] limitsX = getLimits(com[0]);
    double[] limitsY = getLimits(com[1]);
    plot.setLimits(1, psf.getSize(), Math.min(limitsX[0], limitsY[0]), Math.max(limitsX[1], limitsY[1]));
    // Raw drift as dots with the LOESS-smoothed drift as a line
    plot.setColor(Color.red);
    plot.addPoints(slice, com[0], Plot.DOT);
    plot.addPoints(slice, loess.smooth(slice, com[0]), Plot.LINE);
    plot.setColor(Color.blue);
    plot.addPoints(slice, com[1], Plot.DOT);
    plot.addPoints(slice, loess.smooth(slice, com[1]), Plot.LINE);
    Utils.display(title, plot);
    // TODO - Redraw the PSF with drift correction applied. 
    // This means that the final image should have no drift.
    // This is relevant when combining PSF images. It doesn't matter too much for simulations 
    // unless the drift is large.
    // Add Image properties containing the PSF details
    final double fwhm = getFWHM(psf, maxz);
    psfImp.setProperty("Info", XmlUtils.toXML(new PSFSettings(maxz, nmPerPixel / magnification, nmPerSlice, stats.getN(), fwhm, createNote())));
    Utils.log("%s : z-centre = %d, nm/Pixel = %s, nm/Slice = %s, %d images, PSF SD = %s nm, FWHM = %s px\n", psfImp.getTitle(), maxz, Utils.rounded(nmPerPixel / magnification, 3), Utils.rounded(nmPerSlice, 3), stats.getN(), Utils.rounded(fittedSd * nmPerPixel, 4), Utils.rounded(fwhm));
    createInteractivePlots(psf, maxz, nmPerPixel / magnification, fittedSd * nmPerPixel);
    IJ.showStatus("");
}
Also used: ImageStack(ij.ImageStack) BasePoint(gdsc.core.match.BasePoint) ArrayList(java.util.ArrayList) StoredDataStatistics(gdsc.core.utils.StoredDataStatistics) Rectangle(java.awt.Rectangle) Plot2(ij.gui.Plot2) Statistics(gdsc.core.utils.Statistics) DescriptiveStatistics(org.apache.commons.math3.stat.descriptive.DescriptiveStatistics) Point(java.awt.Point) PeakResult(gdsc.smlm.results.PeakResult) LoessInterpolator(org.apache.commons.math3.analysis.interpolation.LoessInterpolator) MemoryPeakResults(gdsc.smlm.results.MemoryPeakResults) ImageExtractor(gdsc.core.utils.ImageExtractor) PSFSettings(gdsc.smlm.ij.settings.PSFSettings)

Example 2 with Weight

use of org.apache.commons.math3.optim.nonlinear.vector.Weight in project GDSC-SMLM by aherbert.

Class FIRE, method createImages.

/**
	 * Creates the images to use for the FIRE calculation. This must be called after
	 * {@link #initialise(MemoryPeakResults, MemoryPeakResults)}.
	 * <p>
	 * If a second set of results is available, each set is rendered into its own image. Otherwise the single set of
	 * results is split into blocks of consecutive localisations. The block size starts at {@code FIRE.blockSize} and
	 * is reduced as necessary to give at least two blocks. The blocks are rendered alternately into the two images,
	 * in a shuffled order when {@code randomSplit} is set.
	 *
	 * @param fourierImageScale
	 *            the fourier image scale (set to zero to auto compute)
	 * @param imageSize
	 *            the image size
	 * @param useSignal
	 *            Use the localisation signal to weight the intensity. The default uses a value of 1 per localisation.
	 * @return the fire images, or null if there are no results or too few localisations to split into two images
	 */
public FireImages createImages(double fourierImageScale, int imageSize, boolean useSignal) {
    if (results == null || results.size() == 0)
        return null;
    // Block sampling needs at least 2 localisations. With a single localisation the block size below
    // would be computed as 0 (size / 2), and the block count would become Integer.MAX_VALUE.
    if (this.results2 == null && results.size() < 2)
        return null;
    // Use the signal only when requested and the results carry one (judged from the first result)
    final SignalProvider signalProvider = (useSignal && (results.getHead().getSignal() > 0)) ? new PeakSignalProvider() : new FixedSignalProvider();
    // Draw images using the existing IJ routines.
    Rectangle bounds = new Rectangle(0, 0, (int) Math.ceil(dataBounds.getWidth()), (int) Math.ceil(dataBounds.getHeight()));
    boolean weighted = true;
    boolean equalised = false;
    final double imageScale;
    if (fourierImageScale <= 0) {
        // Auto-scale so the largest data dimension is rendered as imageSize pixels
        double size = FastMath.max(bounds.width, bounds.height);
        if (size <= 0)
            size = 1;
        imageScale = imageSize / size;
    } else {
        imageScale = fourierImageScale;
    }
    IJImagePeakResults image1 = ImagePeakResultsFactory.createPeakResultsImage(ResultsImage.NONE, weighted, equalised, "IP1", bounds, 1, 1, imageScale, 0, ResultsMode.ADD);
    image1.setDisplayImage(false);
    image1.begin();
    IJImagePeakResults image2 = ImagePeakResultsFactory.createPeakResultsImage(ResultsImage.NONE, weighted, equalised, "IP2", bounds, 1, 1, imageScale, 0, ResultsMode.ADD);
    image2.setDisplayImage(false);
    image2.begin();
    // Localisations are rendered relative to the data bounds origin
    float minx = (float) dataBounds.getX();
    float miny = (float) dataBounds.getY();
    if (this.results2 != null) {
        // Two image comparison
        for (PeakResult p : results) {
            float x = p.getXPosition() - minx;
            float y = p.getYPosition() - miny;
            image1.add(x, y, signalProvider.getSignal(p));
        }
        for (PeakResult p : results2) {
            float x = p.getXPosition() - minx;
            float y = p.getYPosition() - miny;
            image2.add(x, y, signalProvider.getSignal(p));
        }
    } else {
        // Block sampling.
        // Ensure we have at least 2 even sized blocks (results.size() >= 2 so blockSize >= 1).
        int blockSize = Math.min(results.size() / 2, Math.max(1, FIRE.blockSize));
        int nBlocks = (int) Math.ceil((double) results.size() / blockSize);
        while (nBlocks <= 1 && blockSize > 1) {
            blockSize /= 2;
            nBlocks = (int) Math.ceil((double) results.size() / blockSize);
        }
        if (nBlocks <= 1) {
            // Defensive: not expected since the results contain at least 2 localisations
            return null;
        }
        if (blockSize != FIRE.blockSize)
            IJ.log(TITLE + " Warning: Changed block size to " + blockSize);
        // Partition the results, in iteration order, into consecutive blocks
        int i = 0;
        int block = 0;
        PeakResult[][] blocks = new PeakResult[nBlocks][blockSize];
        for (PeakResult p : results) {
            if (i == blockSize) {
                block++;
                i = 0;
            }
            blocks[block][i++] = p;
        }
        // Truncate last block
        blocks[block] = Arrays.copyOf(blocks[block], i);
        final int[] indices = Utils.newArray(nBlocks, 0, 1);
        if (randomSplit)
            MathArrays.shuffle(indices);
        for (int index : indices) {
            // Split alternating so just rotate: the current target image is swapped out after each block
            IJImagePeakResults image = image1;
            image1 = image2;
            image2 = image;
            for (PeakResult p : blocks[index]) {
                float x = p.getXPosition() - minx;
                float y = p.getYPosition() - miny;
                image.add(x, y, signalProvider.getSignal(p));
            }
        }
    }
    image1.end();
    ImageProcessor ip1 = image1.getImagePlus().getProcessor();
    image2.end();
    ImageProcessor ip2 = image2.getImagePlus().getProcessor();
    if (maxPerBin > 0 && signalProvider instanceof FixedSignalProvider) {
        // We can eliminate over-sampled pixels: with a fixed signal per localisation the pixel values
        // accumulate per-localisation contributions, so cap them at maxPerBin
        for (int i = ip1.getPixelCount(); i-- > 0; ) {
            if (ip1.getf(i) > maxPerBin)
                ip1.setf(i, maxPerBin);
            if (ip2.getf(i) > maxPerBin)
                ip2.setf(i, maxPerBin);
        }
    }
    return new FireImages(ip1, ip2, nmPerPixel / imageScale);
}
Also used : Rectangle(java.awt.Rectangle) IJImagePeakResults(gdsc.smlm.ij.results.IJImagePeakResults) PeakResult(gdsc.smlm.results.PeakResult) WeightedObservedPoint(org.apache.commons.math3.fitting.WeightedObservedPoint) ImageProcessor(ij.process.ImageProcessor)

Example 3 with Weight

use of org.apache.commons.math3.optim.nonlinear.vector.Weight in project GDSC-SMLM by aherbert.

Class JumpDistanceAnalysis, method doFitJumpDistanceHistogram.

/**
	 * Fit the jump distance histogram using a cumulative sum with the given number of species.
	 * <p>
	 * Results are sorted by the diffusion coefficient ascending.
	 * <p>
	 * For a single species a Levenberg-Marquardt (LVM) fit is attempted first. If it fails, the method falls through
	 * to the bounded mixed-population fit. The bounded fit uses a Powell optimiser, falling back to CMA-ES with
	 * restarts (followed by a Powell re-fit) if Powell fails. The bounded solution is then refined with an LVM fit,
	 * which is accepted only if it lowers the sum-of-squares and all fitted parameters stay above zero.
	 * <p>
	 * Side effects: updates {@code calibrated}, {@code ss} and {@code ic}. {@code lastIC} is updated only when a
	 * result is returned.
	 * 
	 * @param jdHistogram
	 *            The cumulative jump distance histogram. X-axis is um^2, Y-axis is cumulative probability. Must be
	 *            monotonic ascending.
	 * @param estimatedD
	 *            The estimated diffusion coefficient
	 * @param n
	 *            The number of species in the mixed population
	 * @return Array containing: { D (um^2), Fractions }. Can be null if no fit was made or the fit was not valid.
	 */
private double[][] doFitJumpDistanceHistogram(double[][] jdHistogram, double estimatedD, int n) {
    calibrated = isCalibrated();
    if (n == 1) {
        // Fit using a single population model
        LevenbergMarquardtOptimizer lvmOptimizer = new LevenbergMarquardtOptimizer();
        try {
            final JumpDistanceCumulFunction function = new JumpDistanceCumulFunction(jdHistogram[0], jdHistogram[1], estimatedD);
            //@formatter:off
            LeastSquaresProblem problem = new LeastSquaresBuilder()
                    .maxEvaluations(Integer.MAX_VALUE)
                    .maxIterations(3000)
                    .start(function.guess())
                    .target(function.getY())
                    .weight(new DiagonalMatrix(function.getWeights()))
                    .model(function, new MultivariateMatrixFunction() {
                        public double[][] value(double[] point) throws IllegalArgumentException {
                            return function.jacobian(point);
                        }
                    })
                    .build();
            //@formatter:on
            Optimum lvmSolution = lvmOptimizer.optimize(problem);
            double[] fitParams = lvmSolution.getPoint().toArray();
            // Sum of squared residuals. True for an unweighted fit.
            ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
            lastIC = ic = Maths.getAkaikeInformationCriterionFromResiduals(ss, function.x.length, 1);
            double[] coefficients = fitParams;
            double[] fractions = new double[] { 1 };
            logger.info("Fit Jump distance (N=1) : %s, SS = %s, IC = %s (%d evaluations)", formatD(fitParams[0]), Maths.rounded(ss, 4), Maths.rounded(ic, 4), lvmSolution.getEvaluations());
            return new double[][] { coefficients, fractions };
        } catch (TooManyIterationsException e) {
            logger.info("LVM optimiser failed to fit (N=1) : Too many iterations : %s", e.getMessage());
        } catch (ConvergenceException e) {
            logger.info("LVM optimiser failed to fit (N=1) : %s", e.getMessage());
        }
        // The LVM fit failed: fall through to the bounded fitting below with n == 1
    }
    // Uses a weighted sum of n exponential functions, each function models a fraction of the particles.
    // An LVM fit cannot restrict the parameters so the fractions do not go below zero.
    // Use the CustomPowell/CMEASOptimizer which supports bounded fitting.
    MixedJumpDistanceCumulFunctionMultivariate function = new MixedJumpDistanceCumulFunctionMultivariate(jdHistogram[0], jdHistogram[1], estimatedD, n);
    double[] lB = function.getLowerBounds();
    int evaluations = 0;
    PointValuePair constrainedSolution = null;
    MaxEval maxEval = new MaxEval(20000);
    CustomPowellOptimizer powellOptimizer = createCustomPowellOptimizer();
    try {
        // The Powell algorithm can use more general bounds: 0 - Infinity
        constrainedSolution = powellOptimizer.optimize(maxEval, new ObjectiveFunction(function), new InitialGuess(function.guess()), new SimpleBounds(lB, function.getUpperBounds(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY)), new CustomPowellOptimizer.BasisStep(function.step()), GoalType.MINIMIZE);
        evaluations = powellOptimizer.getEvaluations();
        logger.debug("Powell optimiser fit (N=%d) : SS = %f (%d evaluations)", n, constrainedSolution.getValue(), evaluations);
    } catch (TooManyEvaluationsException e) {
        logger.info("Powell optimiser failed to fit (N=%d) : Too many evaluations (%d)", n, powellOptimizer.getEvaluations());
    } catch (TooManyIterationsException e) {
        logger.info("Powell optimiser failed to fit (N=%d) : Too many iterations (%d)", n, powellOptimizer.getIterations());
    } catch (ConvergenceException e) {
        logger.info("Powell optimiser failed to fit (N=%d) : %s", n, e.getMessage());
    }
    if (constrainedSolution == null) {
        logger.info("Trying CMAES optimiser with restarts ...");
        double[] uB = function.getUpperBounds();
        SimpleBounds bounds = new SimpleBounds(lB, uB);
        // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
        double[] s = new double[lB.length];
        for (int i = 0; i < s.length; i++) {
            s[i] = (uB[i] - lB[i]) / 3;
        }
        OptimizationData sigma = new CMAESOptimizer.Sigma(s);
        // NOTE(review): the CMA-ES default population is 4 + floor(3 ln N) for N parameters; this uses
        // the number of histogram points instead (giving a larger population). Kept as-is.
        OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(function.x.length))));
        // Iterate this for stability in the initial guess
        CMAESOptimizer cmaesOptimizer = createCMAESOptimizer();
        for (int i = 0; i <= fitRestarts; i++) {
            // Try from the initial guess
            try {
                PointValuePair solution = cmaesOptimizer.optimize(new InitialGuess(function.guess()), new ObjectiveFunction(function), GoalType.MINIMIZE, bounds, sigma, popSize, maxEval);
                if (constrainedSolution == null || solution.getValue() < constrainedSolution.getValue()) {
                    evaluations = cmaesOptimizer.getEvaluations();
                    constrainedSolution = solution;
                    logger.debug("CMAES optimiser [%da] fit (N=%d) : SS = %f (%d evaluations)", i, n, solution.getValue(), evaluations);
                }
            } catch (TooManyEvaluationsException ignored) {
                // Best-effort restart: a failed attempt is simply discarded
            }
            if (constrainedSolution == null)
                continue;
            // Try from the current optimum
            try {
                PointValuePair solution = cmaesOptimizer.optimize(new InitialGuess(constrainedSolution.getPointRef()), new ObjectiveFunction(function), GoalType.MINIMIZE, bounds, sigma, popSize, maxEval);
                if (solution.getValue() < constrainedSolution.getValue()) {
                    evaluations = cmaesOptimizer.getEvaluations();
                    constrainedSolution = solution;
                    logger.debug("CMAES optimiser [%db] fit (N=%d) : SS = %f (%d evaluations)", i, n, solution.getValue(), evaluations);
                }
            } catch (TooManyEvaluationsException ignored) {
                // Best-effort restart: keep the current optimum
            }
        }
        if (constrainedSolution != null) {
            // Re-optimise with Powell?
            try {
                PointValuePair solution = powellOptimizer.optimize(maxEval, new ObjectiveFunction(function), new InitialGuess(constrainedSolution.getPointRef()), new SimpleBounds(lB, function.getUpperBounds(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY)), new CustomPowellOptimizer.BasisStep(function.step()), GoalType.MINIMIZE);
                if (solution.getValue() < constrainedSolution.getValue()) {
                    // This solution came from the Powell optimiser so report its evaluation count
                    evaluations = powellOptimizer.getEvaluations();
                    constrainedSolution = solution;
                    logger.info("Powell optimiser re-fit (N=%d) : SS = %f (%d evaluations)", n, constrainedSolution.getValue(), evaluations);
                }
            } catch (TooManyEvaluationsException ignored) {
                // Best-effort refinement: keep the CMAES solution
            } catch (TooManyIterationsException ignored) {
                // Best-effort refinement: keep the CMAES solution
            } catch (ConvergenceException ignored) {
                // Best-effort refinement: keep the CMAES solution
            }
        }
    }
    if (constrainedSolution == null) {
        logger.info("Failed to fit N=%d", n);
        return null;
    }
    double[] fitParams = constrainedSolution.getPointRef();
    ss = constrainedSolution.getValue();
    // TODO - Try a bounded BFGS optimiser
    // Try and improve using a LVM fit
    final MixedJumpDistanceCumulFunctionGradient functionGradient = new MixedJumpDistanceCumulFunctionGradient(jdHistogram[0], jdHistogram[1], estimatedD, n);
    Optimum lvmSolution;
    LevenbergMarquardtOptimizer lvmOptimizer = new LevenbergMarquardtOptimizer();
    try {
        //@formatter:off
        LeastSquaresProblem problem = new LeastSquaresBuilder()
                .maxEvaluations(Integer.MAX_VALUE)
                .maxIterations(3000)
                .start(fitParams)
                .target(functionGradient.getY())
                .weight(new DiagonalMatrix(functionGradient.getWeights()))
                .model(functionGradient, new MultivariateMatrixFunction() {
                    public double[][] value(double[] point) throws IllegalArgumentException {
                        return functionGradient.jacobian(point);
                    }
                })
                .build();
        //@formatter:on
        lvmSolution = lvmOptimizer.optimize(problem);
        // Sum of squared residuals. True for an unweighted fit.
        double lvmSs = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
        // All fitted parameters must be above zero
        if (lvmSs < this.ss && Maths.min(lvmSolution.getPoint().toArray()) > 0) {
            logger.info("  Re-fitting improved the SS from %s to %s (-%s%%)", Maths.rounded(this.ss, 4), Maths.rounded(lvmSs, 4), Maths.rounded(100 * (this.ss - lvmSs) / this.ss, 4));
            fitParams = lvmSolution.getPoint().toArray();
            this.ss = lvmSs;
            evaluations += lvmSolution.getEvaluations();
        }
    } catch (TooManyIterationsException e) {
        logger.error("Failed to re-fit : Too many iterations : %s", e.getMessage());
    } catch (ConvergenceException e) {
        logger.error("Failed to re-fit : %s", e.getMessage());
    }
    // Since the fractions must sum to one we subtract 1 degree of freedom from the number of parameters
    ic = Maths.getAkaikeInformationCriterionFromResiduals(this.ss, function.x.length, fitParams.length - 1);
    // Parameters are interleaved as { f0, d0, f1, d1, ... }
    double[] d = new double[n];
    double[] f = new double[n];
    double sum = 0;
    for (int i = 0; i < d.length; i++) {
        f[i] = fitParams[i * 2];
        sum += f[i];
        d[i] = fitParams[i * 2 + 1];
    }
    // Normalise the fractions to sum to one
    for (int i = 0; i < f.length; i++) {
        f[i] /= sum;
    }
    // Sort by coefficient size
    sort(d, f);
    double[] coefficients = d;
    double[] fractions = f;
    logger.info("Fit Jump distance (N=%d) : %s (%s), SS = %s, IC = %s (%d evaluations)", n, formatD(d), format(f), Maths.rounded(this.ss, 4), Maths.rounded(ic, 4), evaluations);
    if (isValid(d, f)) {
        lastIC = ic;
        return new double[][] { coefficients, fractions };
    }
    return null;
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) CustomPowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer)

Example 4 with Weight

use of org.apache.commons.math3.optim.nonlinear.vector.Weight in project GDSC-SMLM by aherbert.

Class BinomialFitter, method fitBinomial.

/**
 * Fit the binomial distribution (n,p) to the cumulative histogram. Performs fitting assuming a fixed n value and
 * attempts to optimise p.
 * <p>
 * p is first optimised within the bounds [0, 1] using a CMA-ES search. The search is run {@code fitRestarts + 1}
 * times; each run starts from the estimate p = mean / n and is then repeated from the best point found so far. The
 * best result is chosen according to the optimisation goal (maximised log-likelihood when using maximum likelihood,
 * otherwise minimised score).
 * <p>
 * When the fit does not use maximum likelihood and more than one histogram point is fitted, the result is refined
 * with an (unbounded) Levenberg-Marquardt least-squares fit. The refined result is kept only if p lies in [0, 1] and
 * its sum-of-squares is lower.
 * <p>
 * Note: if {@code zeroTruncated} is true and {@code histogram[0] > 0}, the input histogram is modified in place. The
 * zero bin is cleared and the remaining bins are divided by their sum.
 *
 * @param histogram
 *            The input histogram
 * @param mean
 *            The histogram mean (used to estimate p). Calculated if NaN.
 * @param n
 *            The n to evaluate
 * @param zeroTruncated
 *            True if the model should ignore n=0 (zero-truncated binomial)
 * @return The best fit. The point holds the fitted p and the value is the sum-of-squares of the fit. For maximum
 *         likelihood fits this value is recomputed from the fitted p so results are comparable across different n.
 *         Returns null if there are no points to fit or the fit failed.
 * @throws IllegalArgumentException
 *             If any of the input data values are negative
 * @throws IllegalArgumentException
 *             If fitting a zero truncated binomial and there are no values above zero
 */
public PointValuePair fitBinomial(double[] histogram, double mean, int n, boolean zeroTruncated) {
    if (Double.isNaN(mean)) {
        mean = getMean(histogram);
    }
    if (zeroTruncated && histogram[0] > 0) {
        log("Fitting zero-truncated histogram but there are zero values - Renormalising to ignore zero");
        double cumul = 0;
        for (int i = 1; i < histogram.length; i++) {
            cumul += histogram[i];
        }
        if (cumul == 0) {
            throw new IllegalArgumentException("Fitting zero-truncated histogram but there are no non-zero values");
        }
        // Renormalise in place (the caller's array is modified)
        histogram[0] = 0;
        for (int i = 1; i < histogram.length; i++) {
            histogram[i] /= cumul;
        }
    }
    int nFittedPoints = Math.min(histogram.length, n + 1) - ((zeroTruncated) ? 1 : 0);
    if (nFittedPoints < 1) {
        log("No points to fit (%d): Histogram.length = %d, n = %d, zero-truncated = %b", nFittedPoints, histogram.length, n, zeroTruncated);
        return null;
    }
    // The model is only fitting the probability p
    // For a binomial n*p = mean => p = mean/n
    double[] initialSolution = new double[] { FastMath.min(mean / n, 1) };
    // Create the function
    BinomialModelFunction function = new BinomialModelFunction(histogram, n, zeroTruncated);
    double[] lB = new double[1];
    double[] uB = new double[] { 1 };
    SimpleBounds bounds = new SimpleBounds(lB, uB);
    // Fit
    // CMAESOptimizer or BOBYQAOptimizer support bounds
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    int maxIterations = 2000;
    //Double.NEGATIVE_INFINITY;
    double stopFitness = 0;
    boolean isActiveCMA = true;
    int diagonalOnly = 0;
    int checkFeasableCount = 1;
    RandomGenerator random = new Well19937c();
    boolean generateStatistics = false;
    ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
    OptimizationData sigma = new CMAESOptimizer.Sigma(new double[] { (uB[0] - lB[0]) / 3 });
    // NOTE(review): the CMA-ES default population size is 4 + floor(3 ln N) for N parameters (N = 1 here, giving 4).
    // ln(2) is used instead, giving 6 - presumably deliberate to widen the search; confirm.
    OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(2))));
    try {
        PointValuePair solution = null;
        boolean noRefit = maximumLikelihood;
        if (n == 1 && zeroTruncated) {
            // No need to fit: a zero-truncated binomial with n=1 can only produce the value 1
            solution = new PointValuePair(new double[] { 1 }, 0);
            noRefit = true;
        } else {
            GoalType goalType = (maximumLikelihood) ? GoalType.MAXIMIZE : GoalType.MINIMIZE;
            // Iteratively fit
            CMAESOptimizer opt = new CMAESOptimizer(maxIterations, stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
            // These optimisation data objects are immutable and can be shared between runs
            ObjectiveFunction objective = new ObjectiveFunction(function);
            MaxIter maxIter = new MaxIter(maxIterations);
            MaxEval maxEval = new MaxEval(maxIterations * 2);
            for (int iteration = 0; iteration <= fitRestarts; iteration++) {
                try {
                    // Start from the initial solution
                    PointValuePair result = opt.optimize(new InitialGuess(initialSolution), objective, goalType, bounds, sigma, popSize, maxIter, maxEval);
                    if (isBetterSolution(result, solution, goalType)) {
                        solution = result;
                    }
                } catch (TooManyEvaluationsException e) {
                    // Best effort: ignore a failed run and keep the best solution found so far
                } catch (TooManyIterationsException e) {
                    // Best effort: ignore a failed run and keep the best solution found so far
                }
                if (solution == null) {
                    continue;
                }
                try {
                    // Also restart from the current optimum
                    PointValuePair result = opt.optimize(new InitialGuess(solution.getPointRef()), objective, goalType, bounds, sigma, popSize, maxIter, maxEval);
                    if (isBetterSolution(result, solution, goalType)) {
                        solution = result;
                    }
                } catch (TooManyEvaluationsException e) {
                    // Best effort: ignore a failed run and keep the best solution found so far
                } catch (TooManyIterationsException e) {
                    // Best effort: ignore a failed run and keep the best solution found so far
                }
            }
            if (solution == null) {
                return null;
            }
        }
        if (noRefit) {
            // Although we fit the log-likelihood, return the sum-of-squares to allow
            // comparison across different n
            double p = solution.getPointRef()[0];
            double ss = 0;
            double[] obs = function.p;
            double[] exp = function.getP(p);
            for (int i = 0; i < obs.length; i++) {
                double d = obs[i] - exp[i];
                ss += d * d;
            }
            return new PointValuePair(solution.getPointRef(), ss);
        }
        // We can do a LVM refit if the number of fitted points is more than 1
        if (nFittedPoints > 1) {
            // Improve SS fit with a gradient based LVM optimizer
            LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
            try {
                final BinomialModelFunctionGradient gradientFunction = new BinomialModelFunctionGradient(histogram, n, zeroTruncated);
                //@formatter:off
                LeastSquaresProblem problem = new LeastSquaresBuilder()
                        .maxEvaluations(Integer.MAX_VALUE)
                        .maxIterations(3000)
                        .start(solution.getPointRef())
                        .target(gradientFunction.p)
                        .weight(new DiagonalMatrix(gradientFunction.getWeights()))
                        .model(gradientFunction, new MultivariateMatrixFunction() {

                            public double[][] value(double[] point) throws IllegalArgumentException {
                                return gradientFunction.jacobian(point);
                            }
                        })
                        .build();
                //@formatter:on
                Optimum lvmSolution = optimizer.optimize(problem);
                // Check the pValue is valid since the LVM is not bounded.
                double p = lvmSolution.getPoint().getEntry(0);
                if (p <= 1 && p >= 0) {
                    // The residuals include the weights, so this is the plain sum-of-squares only if the weights are 1
                    double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
                    if (ss < solution.getValue()) {
                        return new PointValuePair(lvmSolution.getPoint().toArray(), ss);
                    }
                }
            } catch (TooManyIterationsException e) {
                log("Failed to re-fit: Too many iterations: %s", e.getMessage());
            } catch (ConvergenceException e) {
                log("Failed to re-fit: %s", e.getMessage());
            } catch (Exception e) {
                // Best effort: the refit is optional, so keep the CMA-ES solution but record why the refit failed
                log("Failed to re-fit: %s", e.getMessage());
            }
        }
        return solution;
    } catch (Exception e) {
        log("Failed to fit Binomial distribution with N=%d : %s", n, e.getMessage());
    }
    return null;
}

/**
 * Test whether a candidate optimum improves on the current best solution for the given optimisation goal.
 *
 * @param candidate
 *            the new optimiser result
 * @param current
 *            the current best result (may be null)
 * @param goalType
 *            the optimisation goal used to produce both results
 * @return true if {@code current} is null or {@code candidate} has a strictly better value for the goal
 */
private static boolean isBetterSolution(PointValuePair candidate, PointValuePair current, GoalType goalType) {
    if (current == null) {
        return true;
    }
    if (goalType == GoalType.MAXIMIZE) {
        return candidate.getValue() > current.getValue();
    }
    return candidate.getValue() < current.getValue();
}
Also used : InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) Well19937c(org.apache.commons.math3.random.Well19937c) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 5 with Weight

use of org.apache.commons.math3.optim.nonlinear.vector.Weight in project GDSC-SMLM by aherbert.

the class ApacheLVMFitter method computeFit.

/**
 * Fit the current function to the data using the Apache Commons Math {@code LevenbergMarquardtOptimizer} with unit
 * weights (an unweighted least-squares fit).
 * <p>
 * On success the following outputs are updated:
 * <ul>
 * <li>the fitted parameters are written back into {@code a} via {@code setSolution};</li>
 * <li>the {@code iterations} and {@code evaluations} fields are updated;</li>
 * <li>the {@code value} field receives the sum of squared residuals.</li>
 * </ul>
 *
 * @param y
 *            the data to fit
 * @param y_fit
 *            if not null, receives the function values evaluated at the fitted parameters
 * @param a
 *            the parameters; used to build the initial guess (via {@code getInitialSolution}) and updated with the
 *            solution
 * @param a_dev
 *            if not null, receives the parameter deviations. These come from the covariance matrix, or, if that
 *            matrix is singular, from the diagonal of the Fisher information as a loose bound.
 * @return {@code FitStatus.OK} on success, otherwise a status describing the failure
 */
public FitStatus computeFit(double[] y, final double[] y_fit, double[] a, double[] a_dev) {
    int n = y.length;
    try {
        // Different convergence thresholds seem to have no effect on the resulting fit, only the number of
        // iterations for convergence
        final double initialStepBoundFactor = 100;
        final double costRelativeTolerance = 1e-10;
        final double parRelativeTolerance = 1e-10;
        final double orthoTolerance = 1e-10;
        final double threshold = Precision.SAFE_MIN;
        // Extract the parameters to be fitted
        final double[] initialSolution = getInitialSolution(a);
        // TODO - Pass in more advanced stopping criteria.
        // Create the target and weight arrays. Unit weights give an unweighted least-squares fit.
        final double[] yd = y.clone();
        final double[] w = new double[n];
        for (int i = 0; i < n; i++) {
            w[i] = 1;
        }
        LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(initialStepBoundFactor, costRelativeTolerance, parRelativeTolerance, orthoTolerance, threshold);
        //@formatter:off
        LeastSquaresBuilder builder = new LeastSquaresBuilder()
                .maxEvaluations(Integer.MAX_VALUE)
                .maxIterations(getMaxEvaluations())
                .start(initialSolution)
                .target(yd)
                .weight(new DiagonalMatrix(w));
        //@formatter:on
        if (f instanceof ExtendedNonLinearFunction && ((ExtendedNonLinearFunction) f).canComputeValuesAndJacobian()) {
            // Compute together, or each individually
            builder.model(new ValueAndJacobianFunction() {

                final ExtendedNonLinearFunction fun = (ExtendedNonLinearFunction) f;

                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    final double[] p = point.toArray();
                    final Pair<double[], double[][]> result = fun.computeValuesAndJacobian(p);
                    return new Pair<RealVector, RealMatrix>(new ArrayRealVector(result.getFirst(), false), new Array2DRowRealMatrix(result.getSecond(), false));
                }

                public RealVector computeValue(double[] params) {
                    return new ArrayRealVector(fun.computeValues(params), false);
                }

                public RealMatrix computeJacobian(double[] params) {
                    return new Array2DRowRealMatrix(fun.computeJacobian(params), false);
                }
            });
        } else {
            // Compute separately
            builder.model(new MultivariateVectorFunctionWrapper((NonLinearFunction) f, a, n), new MultivariateMatrixFunctionWrapper((NonLinearFunction) f, a, n));
        }
        LeastSquaresProblem problem = builder.build();
        Optimum optimum = optimizer.optimize(problem);
        final double[] parameters = optimum.getPoint().toArray();
        setSolution(a, parameters);
        iterations = optimum.getIterations();
        evaluations = optimum.getEvaluations();
        if (a_dev != null) {
            try {
                double[][] covar = optimum.getCovariances(threshold).getData();
                setDeviationsFromMatrix(a_dev, covar);
            } catch (SingularMatrixException e) {
                // Matrix inversion failed. In order to return a solution
                // return the reciprocal of the diagonal of the Fisher information
                // for a loose bound on the limit
                final int[] gradientIndices = f.gradientIndices();
                final int nparams = gradientIndices.length;
                GradientCalculator calculator = GradientCalculatorFactory.newCalculator(nparams);
                double[][] alpha = new double[nparams][nparams];
                double[] beta = new double[nparams];
                // NOTE(review): the first argument is the parameter count; if findLinearised expects the
                // number of data points this should be n - confirm against GradientCalculator.
                calculator.findLinearised(nparams, y, a, alpha, beta, (NonLinearFunction) f);
                FisherInformationMatrix m = new FisherInformationMatrix(alpha);
                setDeviations(a_dev, m.crlb(true));
            }
        }
        // Compute function value
        if (y_fit != null) {
            Gaussian2DFunction gf = (Gaussian2DFunction) this.f;
            gf.initialise0(a);
            // Assumes forEach visits the data points in the same order as y, and y_fit has room for them
            gf.forEach(new ValueProcedure() {

                int i = 0;

                public void execute(double value) {
                    // Advance the index so each value is stored in its own slot
                    y_fit[i++] = value;
                }
            });
        }
        // As this is unweighted then we can do this to get the sum of squared residuals
        // This is the same as optimum.getCost() * optimum.getCost(); The getCost() function
        // just computes the dot product anyway.
        value = optimum.getResiduals().dotProduct(optimum.getResiduals());
    } catch (TooManyEvaluationsException e) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (TooManyIterationsException e) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (ConvergenceException e) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (Exception e) {
        // TODO - Find out the other exceptions from the fitter and add return values to match. 
        return FitStatus.UNKNOWN;
    }
    return FitStatus.OK;
}
Also used : ValueProcedure(gdsc.smlm.function.ValueProcedure) ExtendedNonLinearFunction(gdsc.smlm.function.ExtendedNonLinearFunction) NonLinearFunction(gdsc.smlm.function.NonLinearFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Gaussian2DFunction(gdsc.smlm.function.gaussian.Gaussian2DFunction) ValueAndJacobianFunction(org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) RealVector(org.apache.commons.math3.linear.RealVector) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) GradientCalculator(gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator) Pair(org.apache.commons.math3.util.Pair) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) FisherInformationMatrix(gdsc.smlm.fitting.FisherInformationMatrix) MultivariateMatrixFunctionWrapper(gdsc.smlm.function.MultivariateMatrixFunctionWrapper) SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) 
Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MultivariateVectorFunctionWrapper(gdsc.smlm.function.MultivariateVectorFunctionWrapper) ExtendedNonLinearFunction(gdsc.smlm.function.ExtendedNonLinearFunction)

Aggregations

Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)31 LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)31 LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)30 LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer)30 DiagonalMatrix (org.apache.commons.math3.linear.DiagonalMatrix)29 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)26 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)26 PointValuePair (org.apache.commons.math3.optim.PointValuePair)16 RealVector (org.apache.commons.math3.linear.RealVector)14 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)12 MultivariateMatrixFunction (org.apache.commons.math3.analysis.MultivariateMatrixFunction)11 RealMatrix (org.apache.commons.math3.linear.RealMatrix)11 Pair (org.apache.commons.math3.util.Pair)11 Plot (ij.gui.Plot)10 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)9 ArrayList (java.util.ArrayList)8 List (java.util.List)8 PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution)8 ArrayRealVector (org.apache.commons.math3.linear.ArrayRealVector)8 InitialGuess (org.apache.commons.math3.optim.InitialGuess)8