Search in sources :

Example 1 with DoubleDoubleBiPredicate

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

Class TrackPopulationAnalysis, method createMixed.

/**
 * Creates the multivariate gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat samples a different random unit vector; the dot product of that vector with each
 * data point is passed to {@code estimate} to create the initial mixture. Repeats run
 * asynchronously and are all awaited before the progress is finished and the failure count is
 * reported. Repeats that fail with a matrix exception or do not converge are discarded.
 *
 * <p>Note: this advances {@code settings.seed} so successive calls use different initialisations.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the converged fit with the highest log-likelihood, or null if no repeat converged
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
    // Fit a mixed multivariate Gaussian with different repeats.
    final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
    final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
    final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
    if (settings.debug) {
        ImageJUtils.log("  Fitting %d components", numComponents);
    }
    final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
    final AtomicInteger failures = new AtomicInteger();
    for (int i = 0; i < settings.repeats; i++) {
        // Sample on the calling thread; each task captures its own vector.
        final double[] vector = sampler.sample();
        results.add(CompletableFuture.supplyAsync(() -> {
            final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
            try {
                // This may also throw the same exceptions due to inversion of the covariance matrix
                final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
                    double dot = 0;
                    for (int j = 0; j < dimensions; j++) {
                        dot += vector[j] * point[j];
                    }
                    return dot;
                });
                final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
                // Log the result. Note: The ImageJ log is synchronized.
                if (settings.debug) {
                    ImageJUtils.log("  Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
                }
                return result ? fitter : null;
            } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
                failures.getAndIncrement();
                if (settings.debug) {
                    ImageJUtils.log("  Fit failed during iteration %d. No variance in a sub-population " + "component (check alpha is not always 1.0).", fitter.getIterations());
                }
            } finally {
                ticker.tick();
            }
            return null;
        }));
    }
    // Wait for every fit to complete and select the best model (highest log-likelihood).
    // The reduction consumes every future, so all tasks have finished (and updated the failure
    // count) before the progress is finished and failures are reported.
    final MultivariateGaussianMixtureExpectationMaximization best;
    try {
        best = results.stream()
            .map(CompletableFuture::join)
            .filter(f -> f != null)
            .max((f1, f2) -> Double.compare(f1.getLogLikelihood(), f2.getLogLikelihood()))
            .orElse(null);
    } finally {
        // Always finish the progress, even if a fit completed exceptionally.
        ImageJUtils.finished();
    }
    if (failures.get() != 0 && settings.debug) {
        ImageJUtils.log("  %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
    }
    return best;
}
Also used : UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) Color(java.awt.Color) Arrays(java.util.Arrays) ByteProcessor(ij.process.ByteProcessor) Calibration(uk.ac.sussex.gdsc.smlm.data.config.CalibrationProtos.Calibration) IntUnaryOperator(java.util.function.IntUnaryOperator) HistogramPlotBuilder(uk.ac.sussex.gdsc.core.ij.HistogramPlot.HistogramPlotBuilder) IdFramePeakResultComparator(uk.ac.sussex.gdsc.smlm.results.sort.IdFramePeakResultComparator) UnaryOperator(java.util.function.UnaryOperator) RealVector(org.apache.commons.math3.linear.RealVector) Evaluation(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation) MultivariateJacobianFunction(org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction) VisibleForTesting(uk.ac.sussex.gdsc.core.data.VisibleForTesting) MemoryPeakResults(uk.ac.sussex.gdsc.smlm.results.MemoryPeakResults) NonPositiveDefiniteMatrixException(org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException) LeastSquaresFactory(org.apache.commons.math3.fitting.leastsquares.LeastSquaresFactory) RowSorter(javax.swing.RowSorter) JFrame(javax.swing.JFrame) LutHelper(uk.ac.sussex.gdsc.core.ij.process.LutHelper) KeyStroke(javax.swing.KeyStroke) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) DistanceUnit(uk.ac.sussex.gdsc.smlm.data.config.UnitProtos.DistanceUnit) KeyEvent(java.awt.event.KeyEvent) WindowAdapter(java.awt.event.WindowAdapter) TextUtils(uk.ac.sussex.gdsc.core.utils.TextUtils) Plot(ij.gui.Plot) TIntHashSet(gnu.trove.set.hash.TIntHashSet) ImagePlus(ij.ImagePlus) DefaultTableCellRenderer(javax.swing.table.DefaultTableCellRenderer) TDoubleArrayList(gnu.trove.list.array.TDoubleArrayList) SumOfSquaredDeviations(uk.ac.sussex.gdsc.core.math.SumOfSquaredDeviations) BasicStroke(java.awt.BasicStroke) RealMatrix(org.apache.commons.math3.linear.RealMatrix) 
MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) FDistribution(org.apache.commons.math3.distribution.FDistribution) PlugIn(ij.plugin.PlugIn) ActionListener(java.awt.event.ActionListener) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) PolygonRoi(ij.gui.PolygonRoi) StoredData(uk.ac.sussex.gdsc.core.utils.StoredData) WindowManager(ij.WindowManager) PeakResult(uk.ac.sussex.gdsc.smlm.results.PeakResult) Supplier(java.util.function.Supplier) PointRoi(ij.gui.PointRoi) Trace(uk.ac.sussex.gdsc.smlm.results.Trace) MultiDialog(uk.ac.sussex.gdsc.core.ij.gui.MultiDialog) UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) GenericDialog(ij.gui.GenericDialog) AbstractTableModel(javax.swing.table.AbstractTableModel) SortUtils(uk.ac.sussex.gdsc.core.utils.SortUtils) Overlay(ij.gui.Overlay) IntFunction(java.util.function.IntFunction) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Pair(org.apache.commons.math3.util.Pair) Mean(uk.ac.sussex.gdsc.core.math.Mean) Window(java.awt.Window) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) AttributePeakResult(uk.ac.sussex.gdsc.smlm.results.AttributePeakResult) JScrollPane(javax.swing.JScrollPane) ConvergenceChecker(org.apache.commons.math3.optim.ConvergenceChecker) ListSelectionListener(javax.swing.event.ListSelectionListener) PeakResultStoreList(uk.ac.sussex.gdsc.smlm.results.PeakResultStoreList) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) DoubleEquality(uk.ac.sussex.gdsc.core.utils.DoubleEquality) TIntObjectHashMap(gnu.trove.map.hash.TIntObjectHashMap) TIntArrayList(gnu.trove.list.array.TIntArrayList) 
Mixers(uk.ac.sussex.gdsc.core.utils.rng.Mixers) TextWindow(ij.text.TextWindow) IntConsumer(java.util.function.IntConsumer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DataException(uk.ac.sussex.gdsc.core.data.DataException) NonBlockingExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.NonBlockingExtendedGenericDialog) ScreenDimensionHelper(uk.ac.sussex.gdsc.core.ij.gui.ScreenDimensionHelper) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) CalibrationWriter(uk.ac.sussex.gdsc.smlm.data.config.CalibrationWriter) ListSelectionEvent(javax.swing.event.ListSelectionEvent) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) JMenuBar(javax.swing.JMenuBar) BufferedTextWindow(uk.ac.sussex.gdsc.core.ij.BufferedTextWindow) ExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.ExtendedGenericDialog) JMenu(javax.swing.JMenu) TIntIntHashMap(gnu.trove.map.hash.TIntIntHashMap) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization) WindowEvent(java.awt.event.WindowEvent) List(java.util.List) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) JTable(javax.swing.JTable) LUT(ij.process.LUT) TypeConverter(uk.ac.sussex.gdsc.core.data.utils.TypeConverter) Roi(ij.gui.Roi) IntStream(java.util.stream.IntStream) PrecisionResultProcedure(uk.ac.sussex.gdsc.smlm.results.procedures.PrecisionResultProcedure) ParameterValidator(org.apache.commons.math3.fitting.leastsquares.ParameterValidator) TDoubleList(gnu.trove.list.TDoubleList) ValidationUtils(uk.ac.sussex.gdsc.core.utils.ValidationUtils) CompletableFuture(java.util.concurrent.CompletableFuture) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) WindowOrganiser(uk.ac.sussex.gdsc.core.ij.plugin.WindowOrganiser) 
SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) AtomicReference(java.util.concurrent.atomic.AtomicReference) SwingConstants(javax.swing.SwingConstants) DoubleUnaryOperator(java.util.function.DoubleUnaryOperator) JMenuItem(javax.swing.JMenuItem) Statistics(uk.ac.sussex.gdsc.core.utils.Statistics) DoubleData(uk.ac.sussex.gdsc.core.utils.DoubleData) TFloatArrayList(gnu.trove.list.array.TFloatArrayList) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) ChiSquaredDistributionTable(uk.ac.sussex.gdsc.smlm.function.ChiSquaredDistributionTable) LutColour(uk.ac.sussex.gdsc.core.ij.process.LutHelper.LutColour) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) ActionEvent(java.awt.event.ActionEvent) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) CalibrationReader(uk.ac.sussex.gdsc.smlm.data.config.CalibrationReader) Consumer(java.util.function.Consumer) ImageWindow(ij.gui.ImageWindow) SimpleRegression(org.apache.commons.math3.stat.regression.SimpleRegression) BinMethod(uk.ac.sussex.gdsc.core.ij.HistogramPlot.BinMethod) HistogramPlot(uk.ac.sussex.gdsc.core.ij.HistogramPlot) ImageJUtils(uk.ac.sussex.gdsc.core.ij.ImageJUtils) TableColumnAdjuster(uk.ac.sussex.gdsc.smlm.ij.gui.TableColumnAdjuster) IJ(ij.IJ) BitSet(java.util.BitSet) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) Collections(java.util.Collections) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) UniformRandomProviders(uk.ac.sussex.gdsc.core.utils.rng.UniformRandomProviders) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) CompletableFuture(java.util.concurrent.CompletableFuture) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) 
AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization)

Example 2 with DoubleDoubleBiPredicate

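Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert. (Note: this example's "Also used" list shows uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate rather than the nested type.)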

Class PoissonFunctionTest, method probabilityMatchesPoissonWithNoGain.

/**
 * Asserts that a {@code PoissonFunction} with a gain of 1 (no gain) has a likelihood matching the
 * {@code PoissonDistribution} probability for each count in the range from
 * {@code getRange(1, mu)}.
 *
 * @param mu the Poisson mean
 */
private static void probabilityMatchesPoissonWithNoGain(final double mu) {
    // Declared locally so the failure message reports the gain actually used by the function.
    final double gain = 1.0;
    // The expected value passed to the function is the Poisson mean.
    final double o = mu;
    final PoissonFunction f = new PoissonFunction(gain);
    final PoissonDistribution pd = new PoissonDistribution(mu);
    final int[] range = getRange(1, mu);
    final int min = range[0];
    final int max = range[1];
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-8, 0);
    for (int x = min; x <= max; x++) {
        final double v1 = f.likelihood(x, o);
        final double v2 = pd.probability(x);
        TestAssertions.assertTest(v1, v2, predicate, FunctionUtils.getSupplier("g=%f, mu=%f, x=%d", gain, mu, x));
    }
}
Also used : PoissonDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)

Example 3 with DoubleDoubleBiPredicate

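Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert. (Note: this example's "Also used" list shows uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate rather than the nested type.)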

Class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMixedMultivariateGaussianDistribution.

@SeededTest
void canCreateMixedMultivariateGaussianDistribution(RandomSeed seed) {
    // Relative weights of the two components; these are normalised by the mixture.
    final double[] weights = { 1, 1 };
    final double[][] data1 = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
    final double[][] data2 = { { 4, 2 }, { 3.5, -1.5 }, { -3.5, 1.0 } };
    // Expected per-component statistics.
    final double[][] means = { getColumnMeans(data1), getColumnMeans(data2) };
    final double[][][] covariances = { getCovariance(data1), getCovariance(data2) };

    // Concatenate the rows of both components into a single data set.
    final double[][] data = new double[data1.length + data2.length][];
    System.arraycopy(data1, 0, data, 0, data1.length);
    System.arraycopy(data2, 0, data, data1.length, data2.length);
    // Component labels for each row. These do not have to be zero based.
    final int[] components = { -1, -1, -1, 3, 3, 3 };

    final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-8);
    for (int repeat = 0; repeat < 3; repeat++) {
        // Shuffle the data and labels with identically seeded generators so that they remain
        // paired after the shuffle.
        final long shuffleSeed = repeat + seed.getSeedAsLong();
        RandomUtils.shuffle(data, RngUtils.create(shuffleSeed));
        RandomUtils.shuffle(components, RngUtils.create(shuffleSeed));

        final MixtureMultivariateGaussianDistribution mixture = MultivariateGaussianMixtureExpectationMaximization.createMixed(data, components);
        Assertions.assertArrayEquals(new double[] { 0.5, 0.5 }, mixture.getWeights());
        final MultivariateGaussianDistribution[] components2 = mixture.getDistributions();
        Assertions.assertEquals(weights.length, components2.length);
        for (int c = 0; c < means.length; c++) {
            TestAssertions.assertArrayTest(means[c], components2[c].getMeans(), closeTo);
            TestAssertions.assertArrayTest(covariances[c], components2[c].getCovariances(), closeTo);
        }
    }
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 4 with DoubleDoubleBiPredicate

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

Class MultivariateGaussianMixtureExpectationMaximization, method fit.

/**
 * Fit a mixture model to the data supplied to the constructor.
 *
 * <p>The quality of the fit depends on the concavity of the data provided to the constructor and
 * the initial mixture provided to this function. If the data has many local optima, multiple runs
 * of the fitting function with different initial mixtures may be required to find the optimal
 * solution. If a SingularMatrixException is encountered, it is possible that another
 * initialisation would work.
 *
 * <p>Each iteration performs an E-step using the current model and records its log-likelihood,
 * then performs an M-step to create the next model. Convergence is tested on the
 * (previous, current) log-likelihoods of successive iterations; the first iteration is compared
 * against {@code -Double.MAX_VALUE}. Up to {@code maxIterations + 1} iterations may be performed,
 * and the iteration count always equals the number of iterations actually performed.
 *
 * <p>Note: the stored log-likelihood is computed in the final E-step, i.e. from the model before
 * the final M-step update that becomes the fitted model.
 *
 * @param initialMixture model containing initial values of weights and multivariate normals
 * @param maxIterations maximum iterations allowed for fit
 * @param convergencePredicate convergence predicate used to test the log-likelihoods between
 *        successive iterations (previous, current)
 * @return true if converged within the threshold
 * @throws IllegalArgumentException if maxIterations is less than one or if initialMixture mean
 *         vector and data number of columns are not equal
 * @throws SingularMatrixException if any component's covariance matrix is singular during
 *         fitting.
 * @throws NonPositiveDefiniteMatrixException if any component's covariance matrix is not positive
 *         definite during fitting.
 */
public boolean fit(MixtureMultivariateGaussianDistribution initialMixture, int maxIterations, DoubleDoubleBiPredicate convergencePredicate) {
    ValidationUtils.checkStrictlyPositive(maxIterations, "maxIterations");
    ValidationUtils.checkNotNull(convergencePredicate, "convergencePredicate");
    ValidationUtils.checkNotNull(initialMixture, "initialMixture");
    final int n = data.length;
    final int k = initialMixture.weights.length;
    // Number of data columns.
    final int numCols = data[0].length;
    final int numMeanColumns = initialMixture.distributions[0].means.length;
    ValidationUtils.checkArgument(numCols == numMeanColumns, "Mixture model dimension mismatch with data columns: %d != %d", numCols, numMeanColumns);
    logLikelihood = -Double.MAX_VALUE;
    iterations = 0;
    // Initialize model to fit to initial mixture.
    fittedModel = initialMixture;
    // Allows iterations 1 .. maxIterations + 1 (the original iteration limit). The counter is
    // incremented inside the loop so that on exit without convergence it equals the number of
    // iterations performed rather than overshooting by one.
    while (iterations <= maxIterations) {
        iterations++;
        final double previousLogLikelihood = logLikelihood;
        double sumLogLikelihood = 0;
        // Weight and distribution of each component
        final double[] weights = fittedModel.weights;
        final MultivariateGaussianDistribution[] mvns = fittedModel.distributions;
        // E-step: compute the data dependent parameters of the expectation function.
        // The percentage of row's total density between a row and a component
        final double[][] gamma = new double[n][k];
        // Sum of gamma for each component
        final double[] gammaSums = new double[k];
        // Sum of gamma times its row for each each component
        final double[][] gammaDataProdSums = new double[k][numCols];
        // Cache for the weight multiplied by the distribution density
        final double[] mvnDensity = new double[k];
        for (int i = 0; i < n; i++) {
            final double[] point = data[i];
            // Compute densities for each component and the row density
            double rowDensity = 0;
            for (int j = 0; j < k; j++) {
                final double d = weights[j] * mvns[j].density(point);
                mvnDensity[j] = d;
                rowDensity += d;
            }
            // NOTE(review): a zero row density gives -Infinity here and NaN gamma values below.
            sumLogLikelihood += Math.log(rowDensity);
            for (int j = 0; j < k; j++) {
                gamma[i][j] = mvnDensity[j] / rowDensity;
                gammaSums[j] += gamma[i][j];
                for (int col = 0; col < numCols; col++) {
                    gammaDataProdSums[j][col] += gamma[i][j] * point[col];
                }
            }
        }
        // Log-likelihood of the model used in this E-step (before the M-step update).
        logLikelihood = sumLogLikelihood;
        // M-step: compute the new parameters based on the expectation function.
        final double[] newWeights = new double[k];
        final double[][] newMeans = new double[k][numCols];
        for (int j = 0; j < k; j++) {
            newWeights[j] = gammaSums[j] / n;
            // NOTE(review): a component with gammaSums[j] == 0 produces NaN means.
            for (int col = 0; col < numCols; col++) {
                newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
            }
        }
        // Compute new covariance matrices.
        // These are symmetric so we compute the triangular half.
        final double[][][] newCovMats = new double[k][numCols][numCols];
        final double[] vec = new double[numCols];
        for (int i = 0; i < n; i++) {
            final double[] point = data[i];
            for (int j = 0; j < k; j++) {
                subtract(point, newMeans[j], vec);
                final double g = gamma[i][j];
                // covariance = vec * vecT
                // covariance[ii][jj] = vec[ii] * vec[jj] * gamma[i][j]
                final double[][] covar = newCovMats[j];
                for (int ii = 0; ii < numCols; ii++) {
                    // pre-compute
                    final double vig = vec[ii] * g;
                    final double[] covari = covar[ii];
                    for (int jj = 0; jj <= ii; jj++) {
                        covari[jj] += vig * vec[jj];
                    }
                }
            }
        }
        // Converting to arrays for use by fitted model
        final MultivariateGaussianDistribution[] distributions = new MultivariateGaussianDistribution[k];
        for (int j = 0; j < k; j++) {
            // Make symmetric and normalise by gamma sum
            final double norm = 1.0 / gammaSums[j];
            final double[][] covar = newCovMats[j];
            for (int ii = 0; ii < numCols; ii++) {
                // diagonal
                covar[ii][ii] *= norm;
                // elements
                for (int jj = 0; jj < ii; jj++) {
                    final double tmp = covar[ii][jj] * norm;
                    covar[ii][jj] = tmp;
                    covar[jj][ii] = tmp;
                }
            }
            distributions[j] = new MultivariateGaussianDistribution(newMeans[j], covar);
        }
        // Update current model
        fittedModel = MixtureMultivariateGaussianDistribution.create(newWeights, distributions);
        // Check convergence
        if (convergencePredicate.test(previousLogLikelihood, logLikelihood)) {
            return true;
        }
    }
    // No convergence
    return false;
}
Also used : MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)

Example 5 with DoubleDoubleBiPredicate

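Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert. (Note: this example's "Also used" list shows uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate rather than the nested type.)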

Class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateUnmixedMultivariateGaussianDistribution.

@Test
void canCreateUnmixedMultivariateGaussianDistribution() {
    final double[][] data = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
    // Reference distribution built directly from the sample mean and covariance.
    final MultivariateGaussianDistribution expected = MultivariateGaussianDistribution.create(getColumnMeans(data), getCovariance(data));
    // Distribution created by the class under test.
    final MultivariateGaussianDistribution actual = MultivariateGaussianMixtureExpectationMaximization.createUnmixed(data);
    final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-6);
    TestAssertions.assertArrayTest(expected.getMeans(), actual.getMeans(), closeTo);
    TestAssertions.assertArrayTest(expected.getCovariances(), actual.getCovariances(), closeTo);
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Aggregations

MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)6 MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution)5 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)5 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)3 MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution)2 MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)2 Pair (org.apache.commons.math3.util.Pair)2 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)2 LocalList (uk.ac.sussex.gdsc.core.utils.LocalList)2 TDoubleList (gnu.trove.list.TDoubleList)1 TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList)1 TFloatArrayList (gnu.trove.list.array.TFloatArrayList)1 TIntArrayList (gnu.trove.list.array.TIntArrayList)1 TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap)1 TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap)1 TIntHashSet (gnu.trove.set.hash.TIntHashSet)1 IJ (ij.IJ)1 ImagePlus (ij.ImagePlus)1 WindowManager (ij.WindowManager)1 GenericDialog (ij.gui.GenericDialog)1