Search in sources:

Example 1 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

In class TrackPopulationAnalysis, method createMixed.

/**
 * Creates the multivariate gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat runs asynchronously with its own randomly sampled unit vector. The initial
 * mixture estimate uses the dot product of each data point with that vector. A repeat is
 * discarded if it does not converge, or if it fails with a non-positive definite or singular
 * covariance matrix. All repeats are completed before the progress is finished and the failure
 * count is reported.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the converged fit with the highest log-likelihood, or null if no repeat converged
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
    // Fit a mixed multivariate Gaussian with different repeats.
    final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
    final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
    final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
    if (settings.debug) {
        ImageJUtils.log("  Fitting %d components", numComponents);
    }
    final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
    final AtomicInteger failures = new AtomicInteger();
    for (int i = 0; i < settings.repeats; i++) {
        // Sample on the calling thread so the sampler is only used by one thread.
        final double[] vector = sampler.sample();
        results.add(CompletableFuture.supplyAsync(() -> {
            final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
            try {
                // This may also throw the same exceptions due to inversion of the covariance matrix
                final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
                    double dot = 0;
                    for (int j = 0; j < dimensions; j++) {
                        dot += vector[j] * point[j];
                    }
                    return dot;
                });
                final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
                // Log the result. Note: The ImageJ log is synchronized.
                if (settings.debug) {
                    ImageJUtils.log("  Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
                }
                return result ? fitter : null;
            } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
                failures.getAndIncrement();
                if (settings.debug) {
                    ImageJUtils.log("  Fit failed during iteration %d. No variance in a sub-population " + "component (check alpha is not always 1.0).", fitter.getIterations());
                }
            } finally {
                ticker.tick();
            }
            return null;
        }));
    }
    // Collect results and return the best model (highest log-likelihood).
    // Every task must be joined before finishing the progress and reading the failure count:
    // the tasks run asynchronously and may still be ticking and counting failures until then.
    final MultivariateGaussianMixtureExpectationMaximization best;
    try {
        best = results.stream()
            .map(CompletableFuture::join)
            .filter(model -> model != null)
            // max keeps the earliest of equal elements, matching a stable descending sort
            .max((f1, f2) -> Double.compare(f1.getLogLikelihood(), f2.getLogLikelihood()))
            .orElse(null);
    } finally {
        ImageJUtils.finished();
    }
    if (failures.get() != 0 && settings.debug) {
        ImageJUtils.log("  %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
    }
    return best;
}
Also used : UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) Color(java.awt.Color) Arrays(java.util.Arrays) ByteProcessor(ij.process.ByteProcessor) Calibration(uk.ac.sussex.gdsc.smlm.data.config.CalibrationProtos.Calibration) IntUnaryOperator(java.util.function.IntUnaryOperator) HistogramPlotBuilder(uk.ac.sussex.gdsc.core.ij.HistogramPlot.HistogramPlotBuilder) IdFramePeakResultComparator(uk.ac.sussex.gdsc.smlm.results.sort.IdFramePeakResultComparator) UnaryOperator(java.util.function.UnaryOperator) RealVector(org.apache.commons.math3.linear.RealVector) Evaluation(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation) MultivariateJacobianFunction(org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction) VisibleForTesting(uk.ac.sussex.gdsc.core.data.VisibleForTesting) MemoryPeakResults(uk.ac.sussex.gdsc.smlm.results.MemoryPeakResults) NonPositiveDefiniteMatrixException(org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException) LeastSquaresFactory(org.apache.commons.math3.fitting.leastsquares.LeastSquaresFactory) RowSorter(javax.swing.RowSorter) JFrame(javax.swing.JFrame) LutHelper(uk.ac.sussex.gdsc.core.ij.process.LutHelper) KeyStroke(javax.swing.KeyStroke) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) DistanceUnit(uk.ac.sussex.gdsc.smlm.data.config.UnitProtos.DistanceUnit) KeyEvent(java.awt.event.KeyEvent) WindowAdapter(java.awt.event.WindowAdapter) TextUtils(uk.ac.sussex.gdsc.core.utils.TextUtils) Plot(ij.gui.Plot) TIntHashSet(gnu.trove.set.hash.TIntHashSet) ImagePlus(ij.ImagePlus) DefaultTableCellRenderer(javax.swing.table.DefaultTableCellRenderer) TDoubleArrayList(gnu.trove.list.array.TDoubleArrayList) SumOfSquaredDeviations(uk.ac.sussex.gdsc.core.math.SumOfSquaredDeviations) BasicStroke(java.awt.BasicStroke) RealMatrix(org.apache.commons.math3.linear.RealMatrix) 
MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) FDistribution(org.apache.commons.math3.distribution.FDistribution) PlugIn(ij.plugin.PlugIn) ActionListener(java.awt.event.ActionListener) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) PolygonRoi(ij.gui.PolygonRoi) StoredData(uk.ac.sussex.gdsc.core.utils.StoredData) WindowManager(ij.WindowManager) PeakResult(uk.ac.sussex.gdsc.smlm.results.PeakResult) Supplier(java.util.function.Supplier) PointRoi(ij.gui.PointRoi) Trace(uk.ac.sussex.gdsc.smlm.results.Trace) MultiDialog(uk.ac.sussex.gdsc.core.ij.gui.MultiDialog) UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) GenericDialog(ij.gui.GenericDialog) AbstractTableModel(javax.swing.table.AbstractTableModel) SortUtils(uk.ac.sussex.gdsc.core.utils.SortUtils) Overlay(ij.gui.Overlay) IntFunction(java.util.function.IntFunction) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Pair(org.apache.commons.math3.util.Pair) Mean(uk.ac.sussex.gdsc.core.math.Mean) Window(java.awt.Window) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) AttributePeakResult(uk.ac.sussex.gdsc.smlm.results.AttributePeakResult) JScrollPane(javax.swing.JScrollPane) ConvergenceChecker(org.apache.commons.math3.optim.ConvergenceChecker) ListSelectionListener(javax.swing.event.ListSelectionListener) PeakResultStoreList(uk.ac.sussex.gdsc.smlm.results.PeakResultStoreList) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) DoubleEquality(uk.ac.sussex.gdsc.core.utils.DoubleEquality) TIntObjectHashMap(gnu.trove.map.hash.TIntObjectHashMap) TIntArrayList(gnu.trove.list.array.TIntArrayList) 
Mixers(uk.ac.sussex.gdsc.core.utils.rng.Mixers) TextWindow(ij.text.TextWindow) IntConsumer(java.util.function.IntConsumer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DataException(uk.ac.sussex.gdsc.core.data.DataException) NonBlockingExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.NonBlockingExtendedGenericDialog) ScreenDimensionHelper(uk.ac.sussex.gdsc.core.ij.gui.ScreenDimensionHelper) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) CalibrationWriter(uk.ac.sussex.gdsc.smlm.data.config.CalibrationWriter) ListSelectionEvent(javax.swing.event.ListSelectionEvent) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) JMenuBar(javax.swing.JMenuBar) BufferedTextWindow(uk.ac.sussex.gdsc.core.ij.BufferedTextWindow) ExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.ExtendedGenericDialog) JMenu(javax.swing.JMenu) TIntIntHashMap(gnu.trove.map.hash.TIntIntHashMap) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization) WindowEvent(java.awt.event.WindowEvent) List(java.util.List) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) JTable(javax.swing.JTable) LUT(ij.process.LUT) TypeConverter(uk.ac.sussex.gdsc.core.data.utils.TypeConverter) Roi(ij.gui.Roi) IntStream(java.util.stream.IntStream) PrecisionResultProcedure(uk.ac.sussex.gdsc.smlm.results.procedures.PrecisionResultProcedure) ParameterValidator(org.apache.commons.math3.fitting.leastsquares.ParameterValidator) TDoubleList(gnu.trove.list.TDoubleList) ValidationUtils(uk.ac.sussex.gdsc.core.utils.ValidationUtils) CompletableFuture(java.util.concurrent.CompletableFuture) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) WindowOrganiser(uk.ac.sussex.gdsc.core.ij.plugin.WindowOrganiser) 
SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) AtomicReference(java.util.concurrent.atomic.AtomicReference) SwingConstants(javax.swing.SwingConstants) DoubleUnaryOperator(java.util.function.DoubleUnaryOperator) JMenuItem(javax.swing.JMenuItem) Statistics(uk.ac.sussex.gdsc.core.utils.Statistics) DoubleData(uk.ac.sussex.gdsc.core.utils.DoubleData) TFloatArrayList(gnu.trove.list.array.TFloatArrayList) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) ChiSquaredDistributionTable(uk.ac.sussex.gdsc.smlm.function.ChiSquaredDistributionTable) LutColour(uk.ac.sussex.gdsc.core.ij.process.LutHelper.LutColour) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) ActionEvent(java.awt.event.ActionEvent) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) CalibrationReader(uk.ac.sussex.gdsc.smlm.data.config.CalibrationReader) Consumer(java.util.function.Consumer) ImageWindow(ij.gui.ImageWindow) SimpleRegression(org.apache.commons.math3.stat.regression.SimpleRegression) BinMethod(uk.ac.sussex.gdsc.core.ij.HistogramPlot.BinMethod) HistogramPlot(uk.ac.sussex.gdsc.core.ij.HistogramPlot) ImageJUtils(uk.ac.sussex.gdsc.core.ij.ImageJUtils) TableColumnAdjuster(uk.ac.sussex.gdsc.smlm.ij.gui.TableColumnAdjuster) IJ(ij.IJ) BitSet(java.util.BitSet) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) Collections(java.util.Collections) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) UniformRandomProviders(uk.ac.sussex.gdsc.core.utils.rng.UniformRandomProviders) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) CompletableFuture(java.util.concurrent.CompletableFuture) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) 
AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization)

Example 2 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

In class TrackPopulationAnalysis, method createModelTable.

/**
 * Shows the final model in a table. The mixture model is rebuilt from the original data using
 * the component assignments. One table row is written for each entry in {@code weights}.
 *
 * <p>Each row holds, tab separated:
 * <ul>
 * <li>the component index;</li>
 * <li>the fraction of data points assigned to that index;</li>
 * <li>the supplied weight;</li>
 * <li>for each dimension, the mean and the standard deviation.</li>
 * </ul>
 *
 * @param data the data
 * @param weights the weights for each component
 * @param component the component assignment of each data point (used directly as an array
 *        index, so presumably zero-based here — TODO confirm with callers)
 */
private static void createModelTable(double[][] data, double[] weights, int[] component) {
    final MixtureMultivariateGaussianDistribution mixture = MultivariateGaussianMixtureExpectationMaximization.createMixed(data, component);
    final MultivariateGaussianDistribution[] dists = mixture.getDistributions();
    // Number of data points assigned to each component index
    final int[] sizes = new int[MathUtils.max(component) + 1];
    for (final int c : component) {
        sizes[c]++;
    }
    try (BufferedTextWindow table = new BufferedTextWindow(ImageJUtils.refresh(modelTableRef, () -> new TextWindow("Track Population Model", createHeader(), "", 800, 300)))) {
        final StringBuilder row = new StringBuilder();
        for (int k = 0; k < weights.length; k++) {
            final double fraction = (double) sizes[k] / component.length;
            row.setLength(0);
            row.append(k).append('\t')
               .append(MathUtils.rounded(fraction)).append('\t')
               .append(MathUtils.rounded(weights[k]));
            final MultivariateGaussianDistribution dist = dists[k];
            final double[] mu = dist.getMeans();
            final double[] sigma = dist.getStandardDeviations();
            for (int d = 0; d < mu.length; d++) {
                row.append('\t').append(MathUtils.rounded(mu[d]));
                row.append('\t').append(MathUtils.rounded(sigma[d]));
            }
            table.append(row.toString());
        }
    }
}
Also used : MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) BufferedTextWindow(uk.ac.sussex.gdsc.core.ij.BufferedTextWindow) TextWindow(ij.text.TextWindow) BufferedTextWindow(uk.ac.sussex.gdsc.core.ij.BufferedTextWindow)

Example 3 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

In class MultivariateGaussianMixtureExpectationMaximizationTest, method testFitThrows.

@Test
void testFitThrows() {
    // A null mixture is enough for the checks that reject the other arguments first
    final MixtureMultivariateGaussianDistribution noMixture = null;
    final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(new double[2][3]);
    // Before any fit: zero log-likelihood, zero iterations and no fitted model
    Assertions.assertEquals(0, fitter.getLogLikelihood());
    Assertions.assertEquals(0, fitter.getIterations());
    Assertions.assertNull(fitter.getFittedModel());
    // A valid iteration limit for the remaining checks
    final int iterations = 10;
    // Iteration limit that is not positive
    Assertions.assertThrows(IllegalArgumentException.class, () -> fitter.fit(noMixture, 0, DEFAULT_CONVERGENCE_CHECKER));
    // Null convergence checker
    Assertions.assertThrows(NullPointerException.class, () -> fitter.fit(noMixture, iterations, null));
    // Null mixture
    Assertions.assertThrows(NullPointerException.class, () -> fitter.fit(noMixture, iterations, DEFAULT_CONVERGENCE_CHECKER));
    // Dimension mismatch: each data row has 3 values but the mixture is an equal mix of 2D Gaussians
    final MultivariateGaussianDistribution[] components = {
        new MultivariateGaussianDistribution(new double[2], new double[][] { { 1, 0 }, { 0, 2 } }),
        new MultivariateGaussianDistribution(new double[2], new double[][] { { 1, 0 }, { 0, 2 } }) };
    final MixtureMultivariateGaussianDistribution mixture2d = new MixtureMultivariateGaussianDistribution(new double[] { 0.5, 0.5 }, components);
    Assertions.assertThrows(IllegalArgumentException.class, () -> fitter.fit(mixture2d, iterations, DEFAULT_CONVERGENCE_CHECKER));
}
Also used : MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Example 4 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

In class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMixedMultivariateGaussianDistribution.

@SeededTest
void canCreateMixedMultivariateGaussianDistribution(RandomSeed seed) {
    // Two groups of three 2D points; each group forms one component of the mixture
    final double[][] group1 = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
    final double[][] group2 = { { 4, 2 }, { 3.5, -1.5 }, { -3.5, 1.0 } };
    final double[][][] groups = { group1, group2 };
    // Expected component parameters, computed directly from each group
    final double[][] expectedMeans = new double[groups.length][];
    final double[][][] expectedCovariances = new double[groups.length][][];
    for (int g = 0; g < groups.length; g++) {
        expectedMeans[g] = getColumnMeans(groups[g]);
        expectedCovariances[g] = getCovariance(groups[g]);
    }
    // Concatenate the groups into one data set
    final LocalList<double[]> points = new LocalList<>();
    for (final double[][] group : groups) {
        points.addAll(Arrays.asList(group));
    }
    final double[][] data = points.toArray(new double[0][]);
    // Component labels do not need to be zero-based
    final int[] components = { -1, -1, -1, 3, 3, 3 };
    final DoubleDoubleBiPredicate closeEnough = TestHelper.doublesAreClose(1e-8);
    for (int repeat = 0; repeat < 3; repeat++) {
        final long start = repeat + seed.getSeedAsLong();
        // Identically seeded generators are assumed to give the data and labels the same permutation
        RandomUtils.shuffle(data, RngUtils.create(start));
        RandomUtils.shuffle(components, RngUtils.create(start));
        final MixtureMultivariateGaussianDistribution mixture = MultivariateGaussianMixtureExpectationMaximization.createMixed(data, components);
        // Equal group sizes give equal weights
        Assertions.assertArrayEquals(new double[] { 0.5, 0.5 }, mixture.getWeights());
        final MultivariateGaussianDistribution[] dists = mixture.getDistributions();
        Assertions.assertEquals(groups.length, dists.length);
        for (int i = 0; i < expectedMeans.length; i++) {
            TestAssertions.assertArrayTest(expectedMeans[i], dists[i].getMeans(), closeEnough);
            TestAssertions.assertArrayTest(expectedCovariances[i], dists[i].getCovariances(), closeEnough);
        }
    }
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 5 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

In class MultivariateGaussianMixtureExpectationMaximizationTest, method testCreateMixtureMultivariateGaussianDistributionThrows.

@Test
void testCreateMixtureMultivariateGaussianDistributionThrows() {
    // Unnormalised weights; the factory is expected to normalise them
    final double[] weights = { 1, 3 };
    final double[][][] groups = {
        { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } },
        { { 4, 2 }, { 3.5, -1.5 }, { -3.5, 1.0 } } };
    final double[][] means = new double[groups.length][];
    final double[][][] covariances = new double[groups.length][][];
    for (int g = 0; g < groups.length; g++) {
        means[g] = getColumnMeans(groups[g]);
        covariances[g] = getCovariance(groups[g]);
    }

    // --- create(weights, means, covariances) ---
    // A negative weight
    Assertions.assertThrows(IllegalArgumentException.class,
        () -> MixtureMultivariateGaussianDistribution.create(new double[] { -1, 1 }, means, covariances));
    // Weights whose sum overflows to infinity
    Assertions.assertThrows(IllegalArgumentException.class,
        () -> MixtureMultivariateGaussianDistribution.create(new double[] { Double.MAX_VALUE, Double.MAX_VALUE }, means, covariances));
    // One mean vector longer than the others
    Assertions.assertThrows(IllegalArgumentException.class, () -> {
        final double[][] longMean = means.clone();
        longMean[0] = Arrays.copyOf(longMean[0], longMean[0].length + 1);
        MixtureMultivariateGaussianDistribution.create(weights, longMean, covariances);
    });
    // Covariance matrix with the wrong shape
    Assertions.assertThrows(IllegalArgumentException.class, () -> {
        final double[][][] badCovariance = covariances.clone();
        badCovariance[0] = new double[3][2];
        MixtureMultivariateGaussianDistribution.create(weights, means, badCovariance);
    });

    // --- create(weights, distributions) ---
    // Build a valid mixture to obtain some distributions
    final MixtureMultivariateGaussianDistribution valid = MixtureMultivariateGaussianDistribution.create(weights, means, covariances);
    // Weights whose sum overflows to infinity
    Assertions.assertThrows(IllegalArgumentException.class,
        () -> MixtureMultivariateGaussianDistribution.create(new double[] { Double.MAX_VALUE, Double.MAX_VALUE }, valid.getDistributions()));
    // Fewer weights than distributions
    Assertions.assertThrows(IllegalArgumentException.class,
        () -> MixtureMultivariateGaussianDistribution.create(new double[] { 1 }, valid.getDistributions()));
}
Also used : MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Aggregations

MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution): 10; MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution): 9; SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest): 6; LocalList (uk.ac.sussex.gdsc.core.utils.LocalList): 4; MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution): 3; Test (org.junit.jupiter.api.Test): 3; Plot (ij.gui.Plot): 2; LUT (ij.process.LUT): 2; TextWindow (ij.text.TextWindow): 2; Color (java.awt.Color): 2; MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 2; Pair (org.apache.commons.math3.util.Pair): 2; UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 2; BufferedTextWindow (uk.ac.sussex.gdsc.core.ij.BufferedTextWindow): 2; DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate): 2; TDoubleList (gnu.trove.list.TDoubleList): 1; TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList): 1; TFloatArrayList (gnu.trove.list.array.TFloatArrayList): 1; TIntArrayList (gnu.trove.list.array.TIntArrayList): 1; TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap): 1