Search in sources :

Example 6 with Region

use of org.apache.commons.math3.geometry.partitioning.Region in project GDSC-SMLM by aherbert.

The class MaximumLikelihoodFitter, method computeFit.

/**
 * Fits the function to the data by minimising the likelihood wrapper's objective with the optimiser
 * selected by {@code searchMethod}.
 * <p>
 * On success the fitted parameters are written back through {@code setSolution(a, solution)}, the
 * {@code value} field is set to the negated optimum (the optimisers minimise, so the sign is reversed
 * to report a maximum likelihood score) and, when requested, parameter deviations are derived from the
 * Fisher information matrix. The {@code iterations} and {@code evaluations} counters are updated in
 * all cases where an optimiser was created.
 *
 * @param y
 *            the observed data; its length is passed to the likelihood wrapper as the number of points
 * @param y_fit
 *            not referenced by this implementation
 * @param a
 *            the initial parameters; used to build the start point and updated with the solution
 * @param a_dev
 *            if not null, receives the deviations computed from the Fisher information matrix
 * @return {@code FitStatus.OK} on success; {@code TOO_MANY_ITERATIONS}, {@code TOO_MANY_EVALUATIONS},
 *         {@code SINGULAR_NON_LINEAR_MODEL}, {@code FAILED_TO_CONVERGE} or {@code UNKNOWN} when the
 *         corresponding exception is caught; {@code INVALID_LIKELIHOOD} when the final value is
 *         infinite or NaN
 * @see gdsc.smlm.fitting.nonlinear.BaseFunctionSolver#computeFit(double[], double[], double[], double[])
 */
public FitStatus computeFit(double[] y, double[] y_fit, double[] a, double[] a_dev) {
    final int n = y.length;
    LikelihoodWrapper maximumLikelihoodFunction = createLikelihoodWrapper((NonLinearFunction) f, n, y, a);
    // Remember the optimiser in use so the finally block can collect its iteration/evaluation counts
    // even when the optimisation throws.
    @SuppressWarnings("rawtypes") BaseOptimizer baseOptimiser = null;
    try {
        double[] startPoint = getInitialSolution(a);
        PointValuePair optimum = null;
        if (searchMethod == SearchMethod.POWELL || searchMethod == SearchMethod.POWELL_BOUNDED || searchMethod == SearchMethod.POWELL_ADAPTER) {
            // Non-differentiable version using Powell Optimiser
            // This is as per the method in Numerical Recipes 10.5 (Direction Set (Powell's) method)
            // I could extend the optimiser and implement bounds on the directions moved. However the mapping
            // adapter seems to work OK.
            final boolean basisConvergence = false;
            // Perhaps these thresholds should be tighter?
            // The default is to use the sqrt() of the overall tolerance
            //final double lineRel = FastMath.sqrt(relativeThreshold);
            //final double lineAbs = FastMath.sqrt(absoluteThreshold);
            //final double lineRel = relativeThreshold * 1e2;
            //final double lineAbs = absoluteThreshold * 1e2;
            // Since we are fitting only a small number of parameters then just use the same tolerance
            // for each search direction
            final double lineRel = relativeThreshold;
            final double lineAbs = absoluteThreshold;
            CustomPowellOptimizer o = new CustomPowellOptimizer(relativeThreshold, absoluteThreshold, lineRel, lineAbs, null, basisConvergence);
            baseOptimiser = o;
            // Only limit the iterations when a positive maximum is configured; otherwise null is passed
            // to the optimiser (presumably ignored - TODO confirm against CustomPowellOptimizer).
            OptimizationData maxIterationData = null;
            if (getMaxIterations() > 0)
                maxIterationData = new MaxIter(getMaxIterations());
            if (searchMethod == SearchMethod.POWELL_ADAPTER) {
                // Try using the mapping adapter for a bounded Powell search.
                // The search runs in the adapter's unbounded space and the result is mapped back.
                MultivariateFunctionMappingAdapter adapter = new MultivariateFunctionMappingAdapter(new MultivariateLikelihood(maximumLikelihoodFunction), lower, upper);
                optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(adapter), GoalType.MINIMIZE, new InitialGuess(adapter.boundedToUnbounded(startPoint)));
                double[] solution = adapter.unboundedToBounded(optimum.getPointRef());
                optimum = new PointValuePair(solution, optimum.getValue());
            } else {
                // The Powell function wrapper is created once and then reused on later calls.
                if (powellFunction == null) {
                    // NOTE(review): the original comment here was truncated. Remaining text:
                    // "... Python code by using the sqrt of the number of photons and background."
                    if (mapGaussian) {
                        Gaussian2DFunction gf = (Gaussian2DFunction) f;
                        // Re-map signal and background using the sqrt
                        int[] indices = gf.gradientIndices();
                        int[] map = new int[indices.length];
                        int count = 0;
                        // Background is always first
                        if (indices[0] == Gaussian2DFunction.BACKGROUND) {
                            map[count++] = 0;
                        }
                        // Look for the Signal in multiple peak 2D Gaussians
                        // (the modulus of 6 presumably reflects 6 parameters per peak - TODO confirm)
                        for (int i = 1; i < indices.length; i++) if (indices[i] % 6 == Gaussian2DFunction.SIGNAL) {
                            map[count++] = i;
                        }
                        // Only use the mapped wrapper when there is something to map
                        if (count > 0) {
                            powellFunction = new MappedMultivariateLikelihood(maximumLikelihoodFunction, Arrays.copyOf(map, count));
                        }
                    }
                    // Fall back to the unmapped likelihood
                    if (powellFunction == null) {
                        powellFunction = new MultivariateLikelihood(maximumLikelihoodFunction);
                    }
                }
                // Update the maximum likelihood function in the Powell function wrapper
                powellFunction.fun = maximumLikelihoodFunction;
                // The position checker is currently disabled (null is passed to the optimiser)
                OptimizationData positionChecker = null;
                // new org.apache.commons.math3.optim.PositionChecker(relativeThreshold, absoluteThreshold);
                // Bounds are only supplied for the POWELL_BOUNDED search method
                SimpleBounds simpleBounds = null;
                if (powellFunction.isMapped()) {
                    // The start point and bounds must be converted to the mapped space and the
                    // solution converted back.
                    MappedMultivariateLikelihood adapter = (MappedMultivariateLikelihood) powellFunction;
                    if (searchMethod == SearchMethod.POWELL_BOUNDED)
                        simpleBounds = new SimpleBounds(adapter.map(lower), adapter.map(upper));
                    optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(adapter.map(startPoint)), positionChecker, simpleBounds);
                    double[] solution = adapter.unmap(optimum.getPointRef());
                    optimum = new PointValuePair(solution, optimum.getValue());
                } else {
                    if (searchMethod == SearchMethod.POWELL_BOUNDED)
                        simpleBounds = new SimpleBounds(lower, upper);
                    optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(startPoint), positionChecker, simpleBounds);
                }
            }
        } else if (searchMethod == SearchMethod.BOBYQA) {
            // Differentiable approximation using Powell's BOBYQA algorithm.
            // This is slower than the Powell optimiser and requires a high number of evaluations.
            int numberOfInterpolationPoints = this.getNumberOfFittedParameters() + 2;
            BOBYQAOptimizer o = new BOBYQAOptimizer(numberOfInterpolationPoints);
            baseOptimiser = o;
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lower, upper));
        } else if (searchMethod == SearchMethod.CMAES) {
            // TODO - Understand why the CMAES optimiser does not fit very well on test data. It appears
            // to converge too early and the likelihood scores are not as low as the other optimisers.
            // CMAESOptimiser based on Matlab code:
            // https://www.lri.fr/~hansen/cmaes.m
            // Take the defaults from the Matlab documentation
            // (alternative stop fitness left in by the author: Double.NEGATIVE_INFINITY)
            double stopFitness = 0;
            boolean isActiveCMA = true;
            int diagonalOnly = 0;
            int checkFeasableCount = 1;
            RandomGenerator random = new Well19937c();
            boolean generateStatistics = false;
            // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
            double[] sigma = new double[lower.length];
            for (int i = 0; i < sigma.length; i++) sigma[i] = (upper[i] - lower[i]) / 3;
            int popSize = (int) (4 + Math.floor(3 * Math.log(sigma.length)));
            // The CMAES optimiser is random and restarting can overcome problems with quick convergence.
            // The Apache Commons documentation states that convergence should occur between 30N and 300N^2
            // function evaluations.
            // The restart budget used here is 30 * N^2 evaluations, capped at half the maximum evaluations.
            final int n30 = FastMath.min(sigma.length * sigma.length * 30, getMaxEvaluations() / 2);
            // The evaluation counter is reset as it controls the restart loop below
            evaluations = 0;
            // data[0] (initial guess) and data[1] (population size) are replaced on later restarts
            OptimizationData[] data = new OptimizationData[] { new InitialGuess(startPoint), new CMAESOptimizer.PopulationSize(popSize), new MaxEval(getMaxEvaluations()), new CMAESOptimizer.Sigma(sigma), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new SimpleBounds(lower, upper) };
            // Iterate to prevent early convergence
            int repeat = 0;
            while (evaluations < n30) {
                // NOTE(review): with the post-increment and "> 1" test the update below first happens
                // on the third pass; the second pass restarts from the original start point and
                // population size. Confirm this is intended.
                if (repeat++ > 1) {
                    // Update the start point and population size
                    data[0] = new InitialGuess(optimum.getPointRef());
                    popSize *= 2;
                    data[1] = new CMAESOptimizer.PopulationSize(popSize);
                }
                CMAESOptimizer o = new CMAESOptimizer(getMaxIterations(), stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
                baseOptimiser = o;
                PointValuePair result = o.optimize(data);
                // Accumulate the counts here since a new optimiser is created for each restart
                iterations += o.getIterations();
                evaluations += o.getEvaluations();
                //		o.getEvaluations(), totalEvaluations);
                // Keep the best (lowest) result across restarts
                if (optimum == null || result.getValue() < optimum.getValue()) {
                    optimum = result;
                }
            }
            // Prevent incrementing the iterations again: the counts were accumulated inside the loop
            // so the finally block must not add them a second time.
            baseOptimiser = null;
        } else if (searchMethod == SearchMethod.BFGS) {
            // BFGS can use an approximate line search minimisation whereas Powell and conjugate gradient
            // methods require a more accurate line minimisation. The BFGS search does not do a full
            // minimisation but takes appropriate steps in the direction of the current gradient.
            // Do not use the convergence checker on the value of the function. Use the convergence on the
            // point coordinate and gradient
            //BFGSOptimizer o = new BFGSOptimizer(new SimpleValueChecker(rel, abs));
            BFGSOptimizer o = new BFGSOptimizer();
            baseOptimiser = o;
            // Configure maximum step length for each dimension using the bounds.
            // A non-positive range means no usable bounds so the step length is unlimited.
            double[] stepLength = new double[lower.length];
            for (int i = 0; i < stepLength.length; i++) {
                stepLength[i] = (upper[i] - lower[i]) * 0.3333333;
                if (stepLength[i] <= 0)
                    stepLength[i] = Double.POSITIVE_INFINITY;
            }
            // The GoalType is always minimise so no need to pass this in
            // The position checker is currently disabled (null is passed to the optimiser)
            OptimizationData positionChecker = null;
            //new org.apache.commons.math3.optim.PositionChecker(relativeThreshold, absoluteThreshold);
            // Note: the step lengths use lower/upper but the optimiser bounds use
            // lowerConstraint/upperConstraint.
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BFGSOptimizer.GradientTolerance(relativeThreshold), positionChecker, new BFGSOptimizer.StepLength(stepLength));
        } else {
            // Any other search method uses a bounded non-linear conjugate gradient optimiser
            // (Fletcher-Reeves for CONJUGATE_GRADIENT_FR, otherwise Polak-Ribiere).
            // The line search algorithm often fails. This is due to searching into a region where the
            // function evaluates to a negative so has been clipped. This means the upper bound of the line
            // cannot be found.
            // Note that running it on an easy problem (200 photons with fixed fitting (no background)) the algorithm
            // does sometimes produce results better than the Powell algorithm but it is slower.
            BoundedNonLinearConjugateGradientOptimizer o = new BoundedNonLinearConjugateGradientOptimizer((searchMethod == SearchMethod.CONJUGATE_GRADIENT_FR) ? Formula.FLETCHER_REEVES : Formula.POLAK_RIBIERE, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
            baseOptimiser = o;
            // Note: The gradients may become unstable at the edge of the bounds. Or they will not change
            // direction if the true solution is on the bounds since the gradient will always continue
            // towards the bounds. This is key to the conjugate gradient method. It searches along a vector
            // until the direction of the gradient is in the opposite direction (using dot products, i.e.
            // cosine of angle between them)
            // NR 10.7 states there is no advantage of the variable metric DFP or BFGS methods over
            // conjugate gradient methods. So I will try these first.
            // Try this:
            // Adapt the conjugate gradient optimiser to use the gradient to pick the search direction
            // and then for the line minimisation. However if the function is out of bounds then clip the
            // variables at the bounds and continue.
            // If the current point is at the bounds and the gradient is to continue out of bounds then
            // clip the gradient too.
            // Or: just use the gradient for the search direction then use the line minimisation/rest
            // as per the Powell optimiser. The bounds should limit the search.
            // I tried a Bounded conjugate gradient optimiser with clipped variables:
            // This sometimes works. However when the variables go a long way out of the expected range the gradients
            // can have vastly different magnitudes. This results in the algorithm stalling since the gradients
            // can be close to zero and then some of the parameters are no longer adjusted.
            // Perhaps this can be looked for and the algorithm then gives up and resorts to a Powell optimiser from
            // the current point.
            // Changed the bracketing step to very small (default is 1, changed to 0.001). This improves the
            // performance. The gradient direction is very sensitive to small changes in the coordinates so a
            // tighter bracketing of the line search helps.
            // Tried using a non-gradient method for the line search copied from the Powell optimiser:
            // This also works when the bracketing step is small but the number of iterations is higher.
            // 24.10.2014: I have tried to get conjugate gradient to work but the gradient function
            // must not behave suitably for the optimiser. In the current state both methods of using a
            // Bounded Conjugate Gradient Optimiser perform poorly relative to other optimisers:
            // Simulated : n=1000, signal=200, x=0.53, y=0.47
            // LVM : n=1000, signal=171, x=0.537, y=0.471 (1.003s)
            // Powell : n=1000, signal=187, x=0.537, y=0.48 (1.238s)
            // Gradient based PR (constrained): n=858, signal=161, x=0.533, y=0.474 (2.54s)
            // Gradient based PR (bounded): n=948, signal=161, x=0.533, y=0.473 (2.67s)
            // Non-gradient based : n=1000, signal=151.47, x=0.535, y=0.474 (1.626s)
            // The conjugate optimisers are slower, under predict the signal by the most and in the case of
            // the gradient based optimiser, fail to converge on some problems. This is worse when constrained
            // fitting is used and not tightly bounded fitting.
            // I will leave the code in as an option but would not recommend using it. I may remove it in the
            // future.
            // Note: It is strange that the non-gradient based line minimisation is more successful.
            // It may be that the gradient function is not accurate (due to round off error) or that it is
            // simply wrong when far from the optimum. My JUnit tests only evaluate the function within the
            // expected range of the answer.
            // Note the default step size on the Powell optimiser is 1 but the initial directions are unit vectors.
            // So our bracketing step should be a minimum of 1 / average length of the first gradient vector to prevent
            // the first step being too large when bracketing.
            // NOTE(review): l below is the sum of squared gradient components (no square root is taken)
            // and the result is capped at 0.001 by the min(), which differs from the description above.
            final double[] gradient = new double[startPoint.length];
            maximumLikelihoodFunction.likelihood(startPoint, gradient);
            double l = 0;
            for (double d : gradient) l += d * d;
            final double bracketingStep = FastMath.min(0.001, ((l > 1) ? 1.0 / l : 1));
            //System.out.printf("Bracketing step = %f (length=%f)\n", bracketingStep, l);
            o.setUseGradientLineSearch(gradientLineMinimisation);
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BoundedNonLinearConjugateGradientOptimizer.BracketingStep(bracketingStep));
        //maximumLikelihoodFunction.value(solution, gradient);
        //System.out.printf("Iter = %d, %g @ %s : %s\n", iterations, ll, Arrays.toString(solution),
        //		Arrays.toString(gradient));
        }
        // Copy the optimised parameters back into the caller's array
        final double[] solution = optimum.getPointRef();
        setSolution(a, solution);
        if (a_dev != null) {
            // Assume the Maximum Likelihood estimator returns the optimum fit (achieves the Cramer-Rao
            // lower bounds) and so the covariance can be obtained from the Fisher Information Matrix.
            FisherInformationMatrix m = new FisherInformationMatrix(maximumLikelihoodFunction.fisherInformation(a));
            setDeviations(a_dev, m.crlb(true));
        }
        // Reverse negative log likelihood for maximum likelihood score
        value = -optimum.getValue();
    } catch (TooManyIterationsException e) {
        //e.printStackTrace();
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (TooManyEvaluationsException e) {
        //e.printStackTrace();
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (ConvergenceException e) {
        //System.out.printf("Singular non linear model = %s\n", e.getMessage());
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (BFGSOptimizer.LineSearchRoundoffException e) {
        //e.printStackTrace();
        return FitStatus.FAILED_TO_CONVERGE;
    } catch (Exception e) {
        // Any other failure is reported on the standard error stream and mapped to UNKNOWN
        //System.out.printf("Unknown error = %s\n", e.getMessage());
        e.printStackTrace();
        return FitStatus.UNKNOWN;
    } finally {
        // Runs on both success and failure: add the counts from the optimiser still registered
        // (the CMAES branch clears it after a successful loop as it has already accumulated its own).
        if (baseOptimiser != null) {
            iterations += baseOptimiser.getIterations();
            evaluations += baseOptimiser.getEvaluations();
        }
    }
    // Check this as likelihood functions can go wrong
    if (Double.isInfinite(value) || Double.isNaN(value))
        return FitStatus.INVALID_LIKELIHOOD;
    return FitStatus.OK;
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) BOBYQAOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) Well19937c(org.apache.commons.math3.random.Well19937c) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) BFGSOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.gradient.BFGSOptimizer) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Gaussian2DFunction(gdsc.smlm.function.gaussian.Gaussian2DFunction) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) BoundedNonLinearConjugateGradientOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) BaseOptimizer(org.apache.commons.math3.optim.BaseOptimizer) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) FisherInformationMatrix(gdsc.smlm.fitting.FisherInformationMatrix) PoissonGammaGaussianLikelihoodWrapper(gdsc.smlm.function.PoissonGammaGaussianLikelihoodWrapper) PoissonGaussianLikelihoodWrapper(gdsc.smlm.function.PoissonGaussianLikelihoodWrapper) PoissonLikelihoodWrapper(gdsc.smlm.function.PoissonLikelihoodWrapper) LikelihoodWrapper(gdsc.smlm.function.LikelihoodWrapper) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) 
ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) MultivariateFunctionMappingAdapter(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter) OptimizationData(org.apache.commons.math3.optim.OptimizationData) CustomPowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 7 with Region

use of org.apache.commons.math3.geometry.partitioning.Region in project MindsEye by SimiaCryptus.

The class ObjectLocation, method run.

/**
 * Runs the object location report.
 * <p>
 * For at most the first 10 input rows the image is classified, the 10 highest scoring categories are
 * selected and an alpha mask is rendered for each of them. The masks are then compared through an
 * eigen-decomposition of their mean-centred dot-product matrix and through a table of pairwise angles,
 * and everything is written to the notebook. Finally the front matter property {@code status} is set
 * to {@code OK}.
 *
 * @param log the notebook output that receives the report; also used to name and create the GPU log file
 */
public void run(@Nonnull final NotebookOutput log) {
    @Nonnull String logName = "cuda_" + log.getName() + ".log";
    log.p(log.file((String) null, logName, "GPU Log"));
    // The stream is handed to CudaSystem as a log sink and is deliberately left open here.
    CudaSystem.addLog(new PrintStream(log.file(logName)));
    ImageClassifier classifier = getClassifierNetwork();
    Layer classifyNetwork = classifier.getNetwork();
    ImageClassifier locator = getLocatorNetwork();
    Layer locatorNetwork = locator.getNetwork();
    ArtistryUtil.setPrecision((DAGNetwork) classifyNetwork, Precision.Float);
    ArtistryUtil.setPrecision((DAGNetwork) locatorNetwork, Precision.Float);
    Tensor[][] inputData = loadImages_library();
    // Tensor[][] inputData = loadImage_Caltech101(log);
    double alphaPower = 0.8;
    final AtomicInteger index = new AtomicInteger(0);
    Arrays.stream(inputData).limit(10).forEach(row -> {
        log.h3("Image " + index.getAndIncrement());
        final Tensor img = row[0];
        log.p(log.image(img.toImage(), ""));
        Result classifyResult = classifyNetwork.eval(new MutableResult(row));
        Result locationResult = locatorNetwork.eval(new MutableResult(row));
        Tensor classification = classifyResult.getData().get(0);
        List<CharSequence> categories = classifier.getCategories();
        // Category indices ordered by descending classification score, best 10 only
        int[] sortedIndices = IntStream.range(0, categories.size()).mapToObj(x -> x).sorted(Comparator.comparing(i -> -classification.get(i))).mapToInt(x -> x).limit(10).toArray();
        logger.info(Arrays.stream(sortedIndices).mapToObj(i -> String.format("%s: %s = %s%%", i, categories.get(i), classification.get(i) * 100)).reduce((a, b) -> a + "\n" + b).orElse(""));
        Map<CharSequence, Tensor> vectors = new HashMap<>();
        List<CharSequence> predictionList = Arrays.stream(sortedIndices).mapToObj(categories::get).collect(Collectors.toList());
        Arrays.stream(sortedIndices).limit(10).forEach(category -> {
            CharSequence name = categories.get(category);
            log.h3(name);
            Tensor alphaTensor = renderAlpha(alphaPower, img, locationResult, classification, category);
            log.p(log.image(img.toRgbImageAlphaMask(0, 1, 2, alphaTensor), ""));
            vectors.put(name, alphaTensor.unit());
        });
        Tensor avgDetection = vectors.values().stream().reduce((a, b) -> a.add(b)).get().scale(1.0 / vectors.size());
        Array2DRowRealMatrix covarianceMatrix = new Array2DRowRealMatrix(predictionList.size(), predictionList.size());
        for (int x = 0; x < predictionList.size(); x++) {
            // The left operand only depends on x so compute it once per row
            final Tensor l = vectors.get(predictionList.get(x)).minus(avgDetection);
            for (int y = 0; y < predictionList.size(); y++) {
                Tensor r = vectors.get(predictionList.get(y)).minus(avgDetection);
                covarianceMatrix.setEntry(x, y, l.dot(r));
            }
        }
        @Nonnull final EigenDecomposition decomposition = new EigenDecomposition(covarianceMatrix);
        // The decomposition only has as many eigenpairs as the matrix has rows. Do not assume 10
        // predictions are available (there may be fewer than 10 categories).
        final int eigenCount = Math.min(10, predictionList.size());
        for (int objectVector = 0; objectVector < eigenCount; objectVector++) {
            log.h3("Eigenobject " + objectVector);
            double eigenvalue = decomposition.getRealEigenvalue(objectVector);
            RealVector eigenvector = decomposition.getEigenvector(objectVector);
            // Combine the category masks weighted by the eigenvector components
            Tensor detectionRegion = IntStream.range(0, eigenvector.getDimension()).mapToObj(i -> vectors.get(predictionList.get(i)).scale(eigenvector.getEntry(i))).reduce((a, b) -> a.add(b)).get();
            detectionRegion = detectionRegion.scale(255.0 / detectionRegion.rms());
            CharSequence categorization = IntStream.range(0, eigenvector.getDimension()).mapToObj(i -> {
                CharSequence category = predictionList.get(i);
                double component = eigenvector.getEntry(i);
                return String.format("<li>%s = %.4f</li>", category, component);
            }).reduce((a, b) -> a + "" + b).get();
            log.p(String.format("Object Detected: <ol>%s</ol>", categorization));
            log.p("Object Eigenvalue: " + eigenvalue);
            log.p("Object Region: " + log.image(img.toRgbImageAlphaMask(0, 1, 2, detectionRegion), ""));
            log.p("Object Region Compliment: " + log.image(img.toRgbImageAlphaMask(0, 1, 2, detectionRegion.scale(-1)), ""));
        }
        // final int[] orderedVectors = IntStream.range(0, 10).mapToObj(x -> x)
        // .sorted(Comparator.comparing(x -> -decomposition.getRealEigenvalue(x))).mapToInt(x -> x).toArray();
        // IntStream.range(0, orderedVectors.length)
        // .mapToObj(i -> {
        // //double realEigenvalue = decomposition.getRealEigenvalue(orderedVectors[i]);
        // return decomposition.getEigenvector(orderedVectors[i]).toArray();
        // }
        // ).toArray(i -> new double[i][]);
        log.p(String.format("<table><tr><th>Cosine Distance</th>%s</tr>%s</table>", Arrays.stream(sortedIndices).limit(10).mapToObj(col -> "<th>" + categories.get(col) + "</th>").reduce((a, b) -> a + b).get(), Arrays.stream(sortedIndices).limit(10).mapToObj(r -> {
            return String.format("<tr><td>%s</td>%s</tr>", categories.get(r), Arrays.stream(sortedIndices).limit(10).mapToObj(col -> {
                // The stored vectors come from unit() so the dot product is presumably a cosine.
                // Rounding can push it marginally outside [-1, 1] (e.g. a vector with itself),
                // for which Math.acos returns NaN, so clamp before taking the angle.
                final double cosine = vectors.get(categories.get(r)).dot(vectors.get(categories.get(col)));
                return String.format("<td>%.4f</td>", Math.acos(Math.max(-1.0, Math.min(1.0, cosine))));
            }).reduce((a, b) -> a + b).get());
        }).reduce((a, b) -> a + b).orElse("")));
    });
    log.setFrontMatterProperty("status", "OK");
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) VGG16_HDF5(com.simiacryptus.mindseye.models.VGG16_HDF5) HashMap(java.util.HashMap) RealVector(org.apache.commons.math3.linear.RealVector) Caltech101(com.simiacryptus.mindseye.test.data.Caltech101) Result(com.simiacryptus.mindseye.lang.Result) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) VGG19_HDF5(com.simiacryptus.mindseye.models.VGG19_HDF5) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) PrintStream(java.io.PrintStream) Util(com.simiacryptus.util.Util) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.cudnn.SoftmaxActivationLayer) Logger(org.slf4j.Logger) BufferedImage(java.awt.image.BufferedImage) BandReducerLayer(com.simiacryptus.mindseye.layers.cudnn.BandReducerLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) ConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) Hdf5Archive(com.simiacryptus.mindseye.models.Hdf5Archive) List(java.util.List) Stream(java.util.stream.Stream) CudaSystem(com.simiacryptus.mindseye.lang.cudnn.CudaSystem) EigenDecomposition(org.apache.commons.math3.linear.EigenDecomposition) PoolingLayer(com.simiacryptus.mindseye.layers.cudnn.PoolingLayer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Comparator(java.util.Comparator) PrintStream(java.io.PrintStream) 
Tensor(com.simiacryptus.mindseye.lang.Tensor) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) Layer(com.simiacryptus.mindseye.lang.Layer) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.cudnn.SoftmaxActivationLayer) BandReducerLayer(com.simiacryptus.mindseye.layers.cudnn.BandReducerLayer) ConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer) PoolingLayer(com.simiacryptus.mindseye.layers.cudnn.PoolingLayer) Result(com.simiacryptus.mindseye.lang.Result) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) EigenDecomposition(org.apache.commons.math3.linear.EigenDecomposition) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RealVector(org.apache.commons.math3.linear.RealVector)

Example 8 with Region

use of org.apache.commons.math3.geometry.partitioning.Region in project GDSC-SMLM by aherbert.

The class PCPALMFitting, method runBoundedOptimiser.

/**
 * Minimise the sum-of-squares of the model function within the given bounds.
 * <p>
 * A BFGS optimiser is tried first (each restart continues from the current optimum). A CMAES
 * optimiser is then run repeatedly, once from the initial solution and once from the current
 * optimum per restart, and the solution with the lowest value is kept.
 * <p>
 * The number of function evaluations actually used by every attempt is accumulated in
 * {@code boundedEvaluations}, which is reset to zero on entry.
 *
 * @param gr              not read by this method
 * @param initialSolution the starting point for the optimisers
 * @param lB              the lower bounds of the parameters
 * @param uB              the upper bounds of the parameters
 * @param function        the model function whose sum-of-squares is minimised
 * @return the best solution found, or null if no optimiser attempt produced a solution
 */
private PointValuePair runBoundedOptimiser(double[][] gr, double[] initialSolution, double[] lB, double[] uB, SumOfSquaresModelFunction function) {
    // Create the functions to optimise
    ObjectiveFunction objective = new ObjectiveFunction(new SumOfSquaresMultivariateFunction(function));
    ObjectiveFunctionGradient gradient = new ObjectiveFunctionGradient(new SumOfSquaresMultivariateVectorFunction(function));
    final boolean debug = false;
    // Try a BFGS optimiser since this will produce a deterministic solution and can respect bounds.
    PointValuePair optimum = null;
    boundedEvaluations = 0;
    final MaxEval maxEvaluations = new MaxEval(2000);
    MultivariateOptimizer opt = null;
    for (int iteration = 0; iteration <= fitRestarts; iteration++) {
        try {
            opt = new BFGSOptimizer();
            final double relativeThreshold = 1e-6;
            // Configure maximum step length for each dimension using the bounds
            double[] stepLength = new double[lB.length];
            for (int i = 0; i < stepLength.length; i++) stepLength[i] = (uB[i] - lB[i]) * 0.3333333;
            // The GoalType is always minimise so no need to pass this in
            optimum = opt.optimize(maxEvaluations, gradient, objective, new InitialGuess((optimum == null) ? initialSolution : optimum.getPointRef()), new SimpleBounds(lB, uB), new BFGSOptimizer.GradientTolerance(relativeThreshold), new BFGSOptimizer.StepLength(stepLength));
            if (debug)
                System.out.printf("BFGS Iter %d = %g (%d)\n", iteration, optimum.getValue(), opt.getEvaluations());
        } catch (RuntimeException e) {
            // Any failure (including TooManyEvaluationsException, which is a RuntimeException)
            // means there is no need to restart
            break;
        } finally {
            // Guard against the optimiser construction itself having failed
            if (opt != null)
                boundedEvaluations += opt.getEvaluations();
        }
    }
    // Try a CMAES optimiser which is non-deterministic. To overcome this we perform restarts.
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    //Double.NEGATIVE_INFINITY;
    double stopFitness = 0;
    boolean isActiveCMA = true;
    int diagonalOnly = 0;
    int checkFeasableCount = 1;
    //Well19937c();
    RandomGenerator random = new Well44497b();
    boolean generateStatistics = false;
    ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
    double[] range = new double[lB.length];
    for (int i = 0; i < lB.length; i++) range[i] = (uB[i] - lB[i]) / 3;
    OptimizationData sigma = new CMAESOptimizer.Sigma(range);
    OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(initialSolution.length))));
    SimpleBounds bounds = new SimpleBounds(lB, uB);
    opt = new CMAESOptimizer(maxEvaluations.getMaxEval(), stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
    // Restart the optimiser several times and take the best answer.
    for (int iteration = 0; iteration <= fitRestarts; iteration++) {
        try {
            // Start from the initial solution
            PointValuePair constrainedSolution = opt.optimize(new InitialGuess(initialSolution), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug)
                System.out.printf("CMAES Iter %d initial = %g (%d)\n", iteration, constrainedSolution.getValue(), opt.getEvaluations());
            if (optimum == null || constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (TooManyEvaluationsException e) {
            // Keep the best solution found so far
        } catch (TooManyIterationsException e) {
            // Keep the best solution found so far
        } finally {
            // Count the evaluations exactly once per attempt, whether it succeeded or failed.
            // Previously a successful attempt was counted twice (actual + maximum evaluations).
            boundedEvaluations += opt.getEvaluations();
        }
        if (optimum == null)
            continue;
        try {
            // Also restart from the current optimum
            PointValuePair constrainedSolution = opt.optimize(new InitialGuess(optimum.getPointRef()), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug)
                System.out.printf("CMAES Iter %d restart = %g (%d)\n", iteration, constrainedSolution.getValue(), opt.getEvaluations());
            if (constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (TooManyEvaluationsException e) {
            // Keep the best solution found so far
        } catch (TooManyIterationsException e) {
            // Keep the best solution found so far
        } finally {
            // Use the real evaluation count rather than always charging the maximum
            boundedEvaluations += opt.getEvaluations();
        }
    }
    return optimum;
}
Also used : MultivariateOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer) MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) BFGSOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.gradient.BFGSOptimizer) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) Well44497b(org.apache.commons.math3.random.Well44497b) OptimizationData(org.apache.commons.math3.optim.OptimizationData)

Example 9 with Region

use of org.apache.commons.math3.geometry.partitioning.Region in project GDSC-SMLM by aherbert.

the class FIRE method runQEstimation.

/**
 * Run the Q-estimation workflow.
 * <p>
 * Loads and crops the input results, builds the localisation precision histogram, computes the
 * FRC curve, then produces an initial estimate of Q by fitting a quadratic to the smoothed log of
 * the scaled FRC numerator divided by the expected decay. The estimate is refined by minimising a
 * plateauness function (optionally also fitting the precision mean and SD) before the interactive
 * Q-estimation dialog is shown.
 * <p>
 * The method returns early, after reporting an error through {@code IJ.error}, if any input is
 * missing or a computation step fails.
 */
private void runQEstimation() {
    IJ.showStatus(TITLE + " ...");
    if (!showQEstimationInputDialog())
        return;
    MemoryPeakResults results = ResultsManager.loadInputResults(inputOption, false);
    if (results == null || results.size() == 0) {
        IJ.error(TITLE, "No results could be loaded");
        return;
    }
    if (results.getCalibration() == null) {
        IJ.error(TITLE, "The results are not calibrated");
        return;
    }
    results = cropToRoi(results);
    if (results.size() < 2) {
        IJ.error(TITLE, "No results within the crop region");
        return;
    }
    initialise(results, null);
    // We need localisation precision.
    // Build a histogram of the localisation precision.
    // Get the initial mean and SD and plot as a Gaussian.
    PrecisionHistogram histogram = calculatePrecisionHistogram();
    if (histogram == null) {
        IJ.error(TITLE, "No localisation precision available.\n \nPlease choose " + PrecisionMethod.FIXED + " and enter a precision mean and SD.");
        return;
    }
    StoredDataStatistics precision = histogram.precision;
    //String name = results.getName();
    double fourierImageScale = SCALE_VALUES[imageScaleIndex];
    int imageSize = IMAGE_SIZE_VALUES[imageSizeIndex];
    // Create the image and compute the numerator of FRC. 
    // Do not use the signal so results.size() is the number of localisations.
    IJ.showStatus("Computing FRC curve ...");
    FireImages images = createImages(fourierImageScale, imageSize, false);
    // DEBUGGING - Save the two images to disk. Load the images into the Matlab 
    // code that calculates the Q-estimation and make this plugin match the functionality.
    //IJ.save(new ImagePlus("i1", images.ip1), "/scratch/i1.tif");
    //IJ.save(new ImagePlus("i2", images.ip2), "/scratch/i2.tif");
    FRC frc = new FRC();
    frc.progress = progress;
    frc.setFourierMethod(fourierMethod);
    frc.setSamplingMethod(samplingMethod);
    frc.setPerimeterSamplingFactor(perimeterSamplingFactor);
    FRCCurve frcCurve = frc.calculateFrcCurve(images.ip1, images.ip2, images.nmPerPixel);
    if (frcCurve == null) {
        IJ.error(TITLE, "Failed to compute FRC curve");
        return;
    }
    IJ.showStatus("Running Q-estimation ...");
    // Note:
    // The method implemented here is based on Matlab code provided by Bernd Rieger.
    // The idea is to compute the spurious correlation component of the FRC Numerator
    // using an initial estimate of distribution of the localisation precision (assumed 
    // to be Gaussian). This component is the contribution of repeat localisations of 
    // the same molecule to the numerator and is modelled as an exponential decay
    // (exp_decay). The component is scaled by the Q-value which
    // is the average number of times a molecule is seen in addition to the first time.
    // At large spatial frequencies the scaled component should match the numerator,
    // i.e. at high resolution (low FIRE number) the numerator is made up of repeat 
    // localisations of the same molecule and not actual structure in the image.
    // The best fit is where the numerator equals the scaled component, i.e. num / (q*exp_decay) == 1.
    // The FRC Numerator is plotted and Q can be determined by
    // adjusting Q and the precision mean and SD to maximise the cost function.
    // This can be done interactively by the user with the effect on the FRC curve
    // dynamically updated and displayed.
    // Compute the scaled FRC numerator
    double qNorm = (1 / frcCurve.mean1 + 1 / frcCurve.mean2);
    double[] frcnum = new double[frcCurve.getSize()];
    for (int i = 0; i < frcnum.length; i++) {
        FRCCurveResult r = frcCurve.get(i);
        frcnum[i] = qNorm * r.getNumerator() / r.getNumberOfSamples();
    }
    // Compute the spatial frequency and the region for curve fitting
    double[] q = FRC.computeQ(frcCurve, false);
    int low = 0, high = q.length;
    while (high > 0 && q[high - 1] > maxQ) high--;
    while (low < q.length && q[low] < minQ) low++;
    // Require we fit at least 10% of the curve
    if (high - low < q.length * 0.1) {
        IJ.error(TITLE, "Not enough points for Q estimation");
        return;
    }
    // Obtain initial estimate of Q plateau height and decay.
    // This can be done by fitting the precision histogram and then fixing the mean and sigma.
    // Or it can be done by allowing the precision to be sampled and the mean and sigma
    // become parameters for fitting.
    // Check if we can sample precision values
    boolean sampleDecay = precision != null && FIRE.sampleDecay;
    double[] exp_decay;
    if (sampleDecay) {
        // Random sample of precision values from the distribution is used to 
        // construct the decay curve
        int[] sample = Random.sample(10000, precision.getN(), new Well19937c());
        final double four_pi2 = 4 * Math.PI * Math.PI;
        double[] pre = new double[q.length];
        for (int i = 1; i < q.length; i++) pre[i] = -four_pi2 * q[i] * q[i];
        // Sample
        final int n = sample.length;
        // The accumulator is indexed by spatial frequency so it must match q.length.
        // Sizing it by the sample count would overflow when the sample is smaller than the curve.
        double[] hq = new double[q.length];
        for (int j = 0; j < n; j++) {
            // Scale to SR pixels
            double s2 = precision.getValue(sample[j]) / images.nmPerPixel;
            s2 *= s2;
            for (int i = 1; i < q.length; i++) hq[i] += FastMath.exp(pre[i] * s2);
        }
        // Convert the sums to the mean over the sample
        for (int i = 1; i < q.length; i++) hq[i] /= n;
        exp_decay = new double[q.length];
        exp_decay[0] = 1;
        for (int i = 1; i < q.length; i++) {
            double sinc_q = sinc(Math.PI * q[i]);
            exp_decay[i] = sinc_q * sinc_q * hq[i];
        }
    } else {
        // Note: The sigma mean and std should be in the units of super-resolution 
        // pixels so scale to SR pixels
        exp_decay = computeExpDecay(histogram.mean / images.nmPerPixel, histogram.sigma / images.nmPerPixel, q);
    }
    // Smoothing
    double[] smooth;
    if (loessSmoothing) {
        // Note: This computes the log then smooths it 
        double bandwidth = 0.1;
        int robustness = 0;
        double[] l = new double[exp_decay.length];
        for (int i = 0; i < l.length; i++) {
            // Original Matlab code computes the log for each array.
            // This is equivalent to a single log on the fraction of the two.
            // Perhaps the two log method is more numerically stable.
            //l[i] = Math.log(Math.abs(frcnum[i])) - Math.log(exp_decay[i]);
            l[i] = Math.log(Math.abs(frcnum[i] / exp_decay[i]));
        }
        try {
            LoessInterpolator loess = new LoessInterpolator(bandwidth, robustness);
            smooth = loess.smooth(q, l);
        } catch (Exception e) {
            IJ.error(TITLE, "LOESS smoothing failed");
            return;
        }
    } else {
        // Note: This smooths the curve before computing the log 
        double[] norm = new double[exp_decay.length];
        for (int i = 0; i < norm.length; i++) {
            norm[i] = frcnum[i] / exp_decay[i];
        }
        // Median window of 5 == radius of 2
        MedianWindow mw = new MedianWindow(norm, 2);
        smooth = new double[exp_decay.length];
        for (int i = 0; i < norm.length; i++) {
            smooth[i] = Math.log(Math.abs(mw.getMedian()));
            mw.increment();
        }
    }
    // Fit with quadratic to find the initial guess.
    // Note: example Matlab code frc_Qcorrection7.m identifies regions of the 
    // smoothed log curve with low derivative and only fits those. The fit is 
    // used for the final estimate. Fitting a subset with low derivative is not 
    // implemented here since the initial estimate is subsequently optimised 
    // to maximise a cost function. 
    Quadratic curve = new Quadratic();
    SimpleCurveFitter fit = SimpleCurveFitter.create(curve, new double[2]);
    WeightedObservedPoints points = new WeightedObservedPoints();
    for (int i = low; i < high; i++) points.add(q[i], smooth[i]);
    double[] estimate = fit.fit(points.toList());
    // The fit is to the log of the curve so invert to obtain Q
    double qValue = FastMath.exp(estimate[0]);
    //System.out.printf("Initial q-estimate = %s => %.3f\n", Arrays.toString(estimate), qValue);
    // This could be made an option. Just use for debugging
    boolean debug = false;
    if (debug) {
        // Plot the initial fit and the fit curve
        double[] qScaled = FRC.computeQ(frcCurve, true);
        double[] line = new double[q.length];
        for (int i = 0; i < q.length; i++) line[i] = curve.value(q[i], estimate);
        String title = TITLE + " Initial fit";
        Plot2 plot = new Plot2(title, "Spatial Frequency (nm^-1)", "FRC Numerator");
        String label = String.format("Q = %.3f", qValue);
        plot.addPoints(qScaled, smooth, Plot.LINE);
        plot.setColor(Color.red);
        plot.addPoints(qScaled, line, Plot.LINE);
        plot.setColor(Color.black);
        plot.addLabel(0, 0, label);
        Utils.display(title, plot, Utils.NO_TO_FRONT);
    }
    if (fitPrecision) {
        // Q - Should this be optional?
        if (sampleDecay) {
            // If a sample of the precision was used to construct the data for the initial fit 
            // then update the estimate using the fit result since it will be a better start point. 
            histogram.sigma = precision.getStandardDeviation();
            // Normalise sum-of-squares to the SR pixel size
            double meanSumOfSquares = (precision.getSumOfSquares() / (images.nmPerPixel * images.nmPerPixel)) / precision.getN();
            histogram.mean = images.nmPerPixel * Math.sqrt(meanSumOfSquares - estimate[1] / (4 * Math.PI * Math.PI));
        }
        // Do a multivariate fit ...
        SimplexOptimizer opt = new SimplexOptimizer(1e-6, 1e-10);
        PointValuePair p = null;
        MultiPlateauness f = new MultiPlateauness(frcnum, q, low, high);
        double[] initial = new double[] { histogram.mean / images.nmPerPixel, histogram.sigma / images.nmPerPixel, qValue };
        // Repeat from scaled versions of the start point; findMin is passed the current best
        p = findMin(p, opt, f, scale(initial, 0.1));
        p = findMin(p, opt, f, scale(initial, 0.5));
        p = findMin(p, opt, f, initial);
        p = findMin(p, opt, f, scale(initial, 2));
        p = findMin(p, opt, f, scale(initial, 10));
        if (p != null) {
            // Convert the fitted precision from SR pixels back to nm
            double[] point = p.getPointRef();
            histogram.mean = point[0] * images.nmPerPixel;
            histogram.sigma = point[1] * images.nmPerPixel;
            qValue = point[2];
        }
    } else {
        // If so then this should be optional.
        if (sampleDecay) {
            if (precisionMethod != PrecisionMethod.FIXED) {
                histogram.sigma = precision.getStandardDeviation();
                // Normalise sum-of-squares to the SR pixel size
                double meanSumOfSquares = (precision.getSumOfSquares() / (images.nmPerPixel * images.nmPerPixel)) / precision.getN();
                histogram.mean = images.nmPerPixel * Math.sqrt(meanSumOfSquares - estimate[1] / (4 * Math.PI * Math.PI));
            }
            exp_decay = computeExpDecay(histogram.mean / images.nmPerPixel, histogram.sigma / images.nmPerPixel, q);
        }
        // Estimate spurious component by promoting plateauness.
        // The Matlab code used random initial points for a Simplex optimiser.
        // A Brent line search should be pretty deterministic so do simple repeats.
        // However it will proceed downhill so if the initial point is wrong then 
        // it will find a sub-optimal result.
        UnivariateOptimizer o = new BrentOptimizer(1e-3, 1e-6);
        Plateauness f = new Plateauness(frcnum, exp_decay, low, high);
        UnivariatePointValuePair p = null;
        p = findMin(p, o, f, qValue, 0.1);
        p = findMin(p, o, f, qValue, 0.2);
        p = findMin(p, o, f, qValue, 0.333);
        p = findMin(p, o, f, qValue, 0.5);
        // Do some Simplex repeats as well
        SimplexOptimizer opt = new SimplexOptimizer(1e-6, 1e-10);
        p = findMin(p, opt, f, qValue * 0.1);
        p = findMin(p, opt, f, qValue * 0.5);
        p = findMin(p, opt, f, qValue);
        p = findMin(p, opt, f, qValue * 2);
        p = findMin(p, opt, f, qValue * 10);
        if (p != null)
            qValue = p.getPoint();
    }
    QPlot qplot = new QPlot(frcCurve, qValue, low, high);
    // Interactive dialog to estimate Q (blinking events per flourophore) using 
    // sliders for the mean and standard deviation of the localisation precision.
    showQEstimationDialog(histogram, qplot, frcCurve, images.nmPerPixel);
    IJ.showStatus(TITLE + " complete");
}
Also used : BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) Plot2(ij.gui.Plot2) Well19937c(org.apache.commons.math3.random.Well19937c) PointValuePair(org.apache.commons.math3.optim.PointValuePair) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) LoessInterpolator(org.apache.commons.math3.analysis.interpolation.LoessInterpolator) WeightedObservedPoints(org.apache.commons.math3.fitting.WeightedObservedPoints) SimplexOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer) MemoryPeakResults(gdsc.smlm.results.MemoryPeakResults) MedianWindow(gdsc.core.utils.MedianWindow) SimpleCurveFitter(org.apache.commons.math3.fitting.SimpleCurveFitter) FRCCurveResult(gdsc.smlm.ij.frc.FRC.FRCCurveResult) StoredDataStatistics(gdsc.core.utils.StoredDataStatistics) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) WeightedObservedPoint(org.apache.commons.math3.fitting.WeightedObservedPoint) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) FRCCurve(gdsc.smlm.ij.frc.FRC.FRCCurve) FRC(gdsc.smlm.ij.frc.FRC) UnivariateOptimizer(org.apache.commons.math3.optim.univariate.UnivariateOptimizer)

Example 10 with Region

use of org.apache.commons.math3.geometry.partitioning.Region in project GDSC-SMLM by aherbert.

the class FIRE method run.

/**
 * Plugin entry point. Computes the FIRE number (and optionally plots the FRC curve) for the
 * selected results, or runs Q estimation when {@code arg} is {@code "q"}.
 * <p>
 * Each input check that fails reports the problem through {@code IJ.error} and returns.
 *
 * @param arg the plugin argument; {@code "q"} selects the Q-estimation workflow
 * @see ij.plugin.PlugIn#run(java.lang.String)
 */
public void run(String arg) {
    extraOptions = Utils.isExtraOptions();
    SMLMUsageTracker.recordPlugin(this.getClass(), arg);
    // Require some fit results and selected regions
    int size = MemoryPeakResults.countMemorySize();
    if (size == 0) {
        IJ.error(TITLE, "There are no fitting results in memory");
        return;
    }
    if ("q".equals(arg)) {
        // NOTE(review): TITLE is modified in place. If TITLE is a static field the suffix
        // will accumulate over repeated invocations - confirm against the declaration.
        TITLE += " Q estimation";
        runQEstimation();
        return;
    }
    IJ.showStatus(TITLE + " ...");
    if (!showInputDialog())
        return;
    MemoryPeakResults results = ResultsManager.loadInputResults(inputOption, false);
    if (results == null || results.size() == 0) {
        IJ.error(TITLE, "No results could be loaded");
        return;
    }
    // The second results set is optional (it may be null)
    MemoryPeakResults results2 = ResultsManager.loadInputResults(inputOption2, false);
    results = cropToRoi(results);
    if (results.size() < 2) {
        IJ.error(TITLE, "No results within the crop region");
        return;
    }
    if (results2 != null) {
        results2 = cropToRoi(results2);
        if (results2.size() < 2) {
            IJ.error(TITLE, "No results2 within the crop region");
            return;
        }
    }
    initialise(results, results2);
    if (!showDialog())
        return;
    long start = System.currentTimeMillis();
    // Compute FIRE
    String name = results.getName();
    double fourierImageScale = SCALE_VALUES[imageScaleIndex];
    int imageSize = IMAGE_SIZE_VALUES[imageSizeIndex];
    // NOTE(review): the test uses the field this.results2 but the name is read from the local
    // results2. Presumably initialise() stores the local in the field - confirm, otherwise
    // the local could be null here.
    if (this.results2 != null) {
        // Two input results sets: a single calculation comparing one against the other
        name += " vs " + results2.getName();
        FireResult result = calculateFireNumber(fourierMethod, samplingMethod, thresholdMethod, fourierImageScale, imageSize);
        if (result != null) {
            logResult(name, result);
            if (showFRCCurve)
                showFrcCurve(name, result, thresholdMethod);
        }
    } else {
        FireResult result = null;
        // Repeats are only used with a random split; always at least one calculation
        int repeats = (randomSplit) ? Math.max(1, FIRE.repeats) : 1;
        if (repeats == 1) {
            result = calculateFireNumber(fourierMethod, samplingMethod, thresholdMethod, fourierImageScale, imageSize);
            if (result != null) {
                logResult(name, result);
                if (showFRCCurve)
                    showFrcCurve(name, result, thresholdMethod);
            }
        } else {
            // Multi-thread this ... 			
            int nThreads = Maths.min(repeats, getThreads());
            ExecutorService executor = Executors.newFixedThreadPool(nThreads);
            TurboList<Future<?>> futures = new TurboList<Future<?>>(repeats);
            TurboList<FIREWorker> workers = new TurboList<FIREWorker>(repeats);
            setProgress(repeats);
            IJ.showProgress(0);
            IJ.showStatus(TITLE + " computing ...");
            // Workers are numbered from 1 to repeats
            for (int i = 1; i <= repeats; i++) {
                FIREWorker w = new FIREWorker(i, fourierImageScale, imageSize);
                workers.add(w);
                futures.add(executor.submit(w));
            }
            // Wait for all to finish
            for (int t = futures.size(); t-- > 0; ) {
                try {
                    // The future .get() method will block until completed
                    futures.get(t).get();
                } catch (Exception e) {
                    // This should not happen. 
                    // Ignore it and allow processing to continue (the number of neighbour samples will just be smaller).  
                    e.printStackTrace();
                }
            }
            IJ.showProgress(1);
            executor.shutdown();
            // Show a combined FRC curve plot of all the smoothed curves if we have multiples.
            LUT valuesLUT = LUTHelper.createLUT(LutColour.FIRE_GLOW);
            @SuppressWarnings("unused") LUT // Black at max value
            noSmoothLUT = LUTHelper.createLUT(LutColour.GRAYS).createInvertedLut();
            LUTHelper.DefaultLUTMapper mapper = new LUTHelper.DefaultLUTMapper(0, repeats);
            FrcCurve curve = new FrcCurve();
            Statistics stats = new Statistics();
            WindowOrganiser wo = new WindowOrganiser();
            boolean oom = false;
            // Collate the worker results. After the loop 'result' holds the last non-null
            // worker result (or remains null if every worker failed).
            for (int i = 0; i < repeats; i++) {
                FIREWorker w = workers.get(i);
                if (w.oom)
                    oom = true;
                if (w.result == null)
                    continue;
                result = w.result;
                // Only valid FIRE numbers contribute to the statistics
                if (!Double.isNaN(result.fireNumber))
                    stats.add(result.fireNumber);
                if (showFRCCurveRepeats) {
                    // Output each FRC curve using a suffix.
                    logResult(w.name, result);
                    wo.add(Utils.display(w.plot.getTitle(), w.plot));
                }
                if (showFRCCurve) {
                    int index = mapper.map(i + 1);
                    //@formatter:off
                    curve.add(name, result, thresholdMethod, LUTHelper.getColour(valuesLUT, index), Color.blue, //LUTHelper.getColour(noSmoothLUT, index)
                    null);
                //@formatter:on
                }
            }
            if (result != null) {
                wo.cascade();
                // Report the mean FIRE number over the repeats
                double mean = stats.getMean();
                logResult(name, result, mean, stats);
                if (showFRCCurve) {
                    curve.addResolution(mean);
                    Plot2 plot = curve.getPlot();
                    Utils.display(plot.getTitle(), plot);
                }
            }
            if (oom) {
                //@formatter:off
                IJ.error(TITLE, "ERROR - Parallel computation out-of-memory.\n \n" + TextUtils.wrap("The number of results will be reduced. " + "Please reduce the size of the Fourier image " + "or change the number of threads " + "using the extra options (hold down the 'Shift' " + "key when running the plugin).", 80));
            //@formatter:on
            }
        }
        // Only do this once
        if (showFRCTimeEvolution && result != null && !Double.isNaN(result.fireNumber))
            showFrcTimeEvolution(name, result.fireNumber, thresholdMethod, nmPerPixel / result.getNmPerPixel(), imageSize);
    }
    IJ.showStatus(TITLE + " complete : " + Utils.timeToString(System.currentTimeMillis() - start));
}
Also used : TurboList(gdsc.core.utils.TurboList) LUTHelper(ij.process.LUTHelper) LUT(ij.process.LUT) WindowOrganiser(ij.plugin.WindowOrganiser) Plot2(ij.gui.Plot2) Statistics(gdsc.core.utils.Statistics) StoredDataStatistics(gdsc.core.utils.StoredDataStatistics) DescriptiveStatistics(org.apache.commons.math3.stat.descriptive.DescriptiveStatistics) WeightedObservedPoint(org.apache.commons.math3.fitting.WeightedObservedPoint) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) ExecutorService(java.util.concurrent.ExecutorService) Future(java.util.concurrent.Future) MemoryPeakResults(gdsc.smlm.results.MemoryPeakResults)

Aggregations

TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)7 PointValuePair (org.apache.commons.math3.optim.PointValuePair)6 WeightedObservedPoint (org.apache.commons.math3.fitting.WeightedObservedPoint)5 ImagePlus (ij.ImagePlus)4 ExtendedGenericDialog (ij.gui.ExtendedGenericDialog)4 LinkedList (java.util.LinkedList)4 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)4 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)4 InitialGuess (org.apache.commons.math3.optim.InitialGuess)4 MaxEval (org.apache.commons.math3.optim.MaxEval)4 OptimizationData (org.apache.commons.math3.optim.OptimizationData)4 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)4 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)4 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)4 ArrayList (java.util.ArrayList)3 CustomPowellOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer)3 Statistics (gdsc.core.utils.Statistics)2 StoredDataStatistics (gdsc.core.utils.StoredDataStatistics)2 MemoryPeakResults (gdsc.smlm.results.MemoryPeakResults)2 Plot2 (ij.gui.Plot2)2