Search in sources :

Example 1 with NonPositiveDefiniteMatrixException

Use of org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException in project GDSC-SMLM by aherbert.

The class TrackPopulationAnalysis, method createMixed.

/**
 * Creates the multivariate gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat is submitted as an asynchronous task. The initial mixture of a repeat is
 * estimated using the projection (dot product) of each data point onto a random vector sampled
 * for that repeat. A repeat is discarded if the fit reports no convergence, or if it throws a
 * {@link NonPositiveDefiniteMatrixException} or {@link SingularMatrixException}.
 *
 * <p>All tasks are awaited before the progress is marked as finished and before the failure
 * summary is logged, so the reported failure count covers every repeat.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the fitter with the highest log-likelihood among the converged repeats; {@code null}
 *         if no repeat converged
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
    // Fit a mixed multivariate Gaussian with different repeats.
    // Note: The seed is advanced so the next invocation uses a different random sequence.
    final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
    final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
    final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
    if (settings.debug) {
        ImageJUtils.log("  Fitting %d components", numComponents);
    }
    final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
    // Incremented from the asynchronous tasks; only read after all tasks have been joined.
    final AtomicInteger failures = new AtomicInteger();
    for (int i = 0; i < settings.repeats; i++) {
        // Sample on the calling thread so the sampler is never shared between tasks.
        final double[] vector = sampler.sample();
        results.add(CompletableFuture.supplyAsync(() -> {
            final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
            try {
                // This may also throw the same exceptions due to inversion of the covariance matrix
                final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
                    // Projection of the point onto the random vector.
                    double dot = 0;
                    for (int j = 0; j < dimensions; j++) {
                        dot += vector[j] * point[j];
                    }
                    return dot;
                });
                final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
                // Log the result. Note: The ImageJ log is synchronized.
                if (settings.debug) {
                    ImageJUtils.log("  Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
                }
                // Only a converged fit is a candidate for the best model.
                return result ? fitter : null;
            } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
                // Expected failure mode: count it and discard this repeat.
                failures.getAndIncrement();
                if (settings.debug) {
                    ImageJUtils.log("  Fit failed during iteration %d. No variance in a sub-population " + "component (check alpha is not always 1.0).", fitter.getIterations());
                }
            } finally {
                // Always report progress, whether the repeat succeeded or failed.
                ticker.tick();
            }
            return null;
        }));
    }
    // Wait for every task and keep the model with the highest log-likelihood. This must happen
    // before the progress is finished and the failure count is read: both are updated by the
    // tasks, which may still be running at this point.
    MultivariateGaussianMixtureExpectationMaximization best = null;
    try {
        for (final CompletableFuture<MultivariateGaussianMixtureExpectationMaximization> future : results) {
            final MultivariateGaussianMixtureExpectationMaximization fit = future.join();
            // Strictly greater: the earliest repeat is retained when log-likelihoods tie.
            if (fit != null && (best == null || Double.compare(fit.getLogLikelihood(), best.getLogLikelihood()) > 0)) {
                best = fit;
            }
        }
    } finally {
        // Clear the progress even if a task completed exceptionally.
        ImageJUtils.finished();
    }
    if (failures.get() != 0 && settings.debug) {
        ImageJUtils.log("  %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
    }
    return best;
}
Also used : UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) Color(java.awt.Color) Arrays(java.util.Arrays) ByteProcessor(ij.process.ByteProcessor) Calibration(uk.ac.sussex.gdsc.smlm.data.config.CalibrationProtos.Calibration) IntUnaryOperator(java.util.function.IntUnaryOperator) HistogramPlotBuilder(uk.ac.sussex.gdsc.core.ij.HistogramPlot.HistogramPlotBuilder) IdFramePeakResultComparator(uk.ac.sussex.gdsc.smlm.results.sort.IdFramePeakResultComparator) UnaryOperator(java.util.function.UnaryOperator) RealVector(org.apache.commons.math3.linear.RealVector) Evaluation(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation) MultivariateJacobianFunction(org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction) VisibleForTesting(uk.ac.sussex.gdsc.core.data.VisibleForTesting) MemoryPeakResults(uk.ac.sussex.gdsc.smlm.results.MemoryPeakResults) NonPositiveDefiniteMatrixException(org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException) LeastSquaresFactory(org.apache.commons.math3.fitting.leastsquares.LeastSquaresFactory) RowSorter(javax.swing.RowSorter) JFrame(javax.swing.JFrame) LutHelper(uk.ac.sussex.gdsc.core.ij.process.LutHelper) KeyStroke(javax.swing.KeyStroke) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) DistanceUnit(uk.ac.sussex.gdsc.smlm.data.config.UnitProtos.DistanceUnit) KeyEvent(java.awt.event.KeyEvent) WindowAdapter(java.awt.event.WindowAdapter) TextUtils(uk.ac.sussex.gdsc.core.utils.TextUtils) Plot(ij.gui.Plot) TIntHashSet(gnu.trove.set.hash.TIntHashSet) ImagePlus(ij.ImagePlus) DefaultTableCellRenderer(javax.swing.table.DefaultTableCellRenderer) TDoubleArrayList(gnu.trove.list.array.TDoubleArrayList) SumOfSquaredDeviations(uk.ac.sussex.gdsc.core.math.SumOfSquaredDeviations) BasicStroke(java.awt.BasicStroke) RealMatrix(org.apache.commons.math3.linear.RealMatrix) 
MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) FDistribution(org.apache.commons.math3.distribution.FDistribution) PlugIn(ij.plugin.PlugIn) ActionListener(java.awt.event.ActionListener) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) PolygonRoi(ij.gui.PolygonRoi) StoredData(uk.ac.sussex.gdsc.core.utils.StoredData) WindowManager(ij.WindowManager) PeakResult(uk.ac.sussex.gdsc.smlm.results.PeakResult) Supplier(java.util.function.Supplier) PointRoi(ij.gui.PointRoi) Trace(uk.ac.sussex.gdsc.smlm.results.Trace) MultiDialog(uk.ac.sussex.gdsc.core.ij.gui.MultiDialog) UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) GenericDialog(ij.gui.GenericDialog) AbstractTableModel(javax.swing.table.AbstractTableModel) SortUtils(uk.ac.sussex.gdsc.core.utils.SortUtils) Overlay(ij.gui.Overlay) IntFunction(java.util.function.IntFunction) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Pair(org.apache.commons.math3.util.Pair) Mean(uk.ac.sussex.gdsc.core.math.Mean) Window(java.awt.Window) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) AttributePeakResult(uk.ac.sussex.gdsc.smlm.results.AttributePeakResult) JScrollPane(javax.swing.JScrollPane) ConvergenceChecker(org.apache.commons.math3.optim.ConvergenceChecker) ListSelectionListener(javax.swing.event.ListSelectionListener) PeakResultStoreList(uk.ac.sussex.gdsc.smlm.results.PeakResultStoreList) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) DoubleEquality(uk.ac.sussex.gdsc.core.utils.DoubleEquality) TIntObjectHashMap(gnu.trove.map.hash.TIntObjectHashMap) TIntArrayList(gnu.trove.list.array.TIntArrayList) 
Mixers(uk.ac.sussex.gdsc.core.utils.rng.Mixers) TextWindow(ij.text.TextWindow) IntConsumer(java.util.function.IntConsumer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DataException(uk.ac.sussex.gdsc.core.data.DataException) NonBlockingExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.NonBlockingExtendedGenericDialog) ScreenDimensionHelper(uk.ac.sussex.gdsc.core.ij.gui.ScreenDimensionHelper) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) CalibrationWriter(uk.ac.sussex.gdsc.smlm.data.config.CalibrationWriter) ListSelectionEvent(javax.swing.event.ListSelectionEvent) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) JMenuBar(javax.swing.JMenuBar) BufferedTextWindow(uk.ac.sussex.gdsc.core.ij.BufferedTextWindow) ExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.ExtendedGenericDialog) JMenu(javax.swing.JMenu) TIntIntHashMap(gnu.trove.map.hash.TIntIntHashMap) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization) WindowEvent(java.awt.event.WindowEvent) List(java.util.List) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) JTable(javax.swing.JTable) LUT(ij.process.LUT) TypeConverter(uk.ac.sussex.gdsc.core.data.utils.TypeConverter) Roi(ij.gui.Roi) IntStream(java.util.stream.IntStream) PrecisionResultProcedure(uk.ac.sussex.gdsc.smlm.results.procedures.PrecisionResultProcedure) ParameterValidator(org.apache.commons.math3.fitting.leastsquares.ParameterValidator) TDoubleList(gnu.trove.list.TDoubleList) ValidationUtils(uk.ac.sussex.gdsc.core.utils.ValidationUtils) CompletableFuture(java.util.concurrent.CompletableFuture) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) WindowOrganiser(uk.ac.sussex.gdsc.core.ij.plugin.WindowOrganiser) 
SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) AtomicReference(java.util.concurrent.atomic.AtomicReference) SwingConstants(javax.swing.SwingConstants) DoubleUnaryOperator(java.util.function.DoubleUnaryOperator) JMenuItem(javax.swing.JMenuItem) Statistics(uk.ac.sussex.gdsc.core.utils.Statistics) DoubleData(uk.ac.sussex.gdsc.core.utils.DoubleData) TFloatArrayList(gnu.trove.list.array.TFloatArrayList) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) ChiSquaredDistributionTable(uk.ac.sussex.gdsc.smlm.function.ChiSquaredDistributionTable) LutColour(uk.ac.sussex.gdsc.core.ij.process.LutHelper.LutColour) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) ActionEvent(java.awt.event.ActionEvent) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) CalibrationReader(uk.ac.sussex.gdsc.smlm.data.config.CalibrationReader) Consumer(java.util.function.Consumer) ImageWindow(ij.gui.ImageWindow) SimpleRegression(org.apache.commons.math3.stat.regression.SimpleRegression) BinMethod(uk.ac.sussex.gdsc.core.ij.HistogramPlot.BinMethod) HistogramPlot(uk.ac.sussex.gdsc.core.ij.HistogramPlot) ImageJUtils(uk.ac.sussex.gdsc.core.ij.ImageJUtils) TableColumnAdjuster(uk.ac.sussex.gdsc.smlm.ij.gui.TableColumnAdjuster) IJ(ij.IJ) BitSet(java.util.BitSet) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) Collections(java.util.Collections) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) UniformRandomProviders(uk.ac.sussex.gdsc.core.utils.rng.UniformRandomProviders) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) CompletableFuture(java.util.concurrent.CompletableFuture) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) 
AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization)

Aggregations

TDoubleList (gnu.trove.list.TDoubleList)1 TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList)1 TFloatArrayList (gnu.trove.list.array.TFloatArrayList)1 TIntArrayList (gnu.trove.list.array.TIntArrayList)1 TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap)1 TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap)1 TIntHashSet (gnu.trove.set.hash.TIntHashSet)1 IJ (ij.IJ)1 ImagePlus (ij.ImagePlus)1 WindowManager (ij.WindowManager)1 GenericDialog (ij.gui.GenericDialog)1 ImageWindow (ij.gui.ImageWindow)1 Overlay (ij.gui.Overlay)1 Plot (ij.gui.Plot)1 PointRoi (ij.gui.PointRoi)1 PolygonRoi (ij.gui.PolygonRoi)1 Roi (ij.gui.Roi)1 PlugIn (ij.plugin.PlugIn)1 ByteProcessor (ij.process.ByteProcessor)1 LUT (ij.process.LUT)1