Search in sources:

Example 6 using MultivariateGaussianMixtureExpectationMaximization

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization in the GDSC-SMLM project by aherbert.

Class MultivariateGaussianMixtureExpectationMaximizationTest, method canFit.

@SeededTest
void canFit(RandomSeed seed) {
    // Cross-check this EM implementation against the Commons Math reference fitter
    // using random 2D Gaussian mixtures with 2 and then 3 components.
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    final int sampleSize = 1000;
    for (int components = 2; components <= 3; components++) {
        // All mixture parameters and the sample draw from the same generator,
        // so the order of these calls determines the generated data.
        final double[] mixWeights = createWeights(components, rng);
        final double[][] mixMeans = create(components, 2, rng, -5, 5);
        final double[][] mixStdDevs = create(components, 2, rng, 1, 10);
        final double[] mixCorrelations = create(components, rng, -0.9, 0.9);
        final double[][] data =
            createData2d(sampleSize, rng, mixWeights, mixMeans, mixStdDevs, mixCorrelations);

        // Fit using the implementation under test.
        final MixtureMultivariateGaussianDistribution startModel =
            MultivariateGaussianMixtureExpectationMaximization.estimate(data, components);
        final MultivariateGaussianMixtureExpectationMaximization actualFitter =
            new MultivariateGaussianMixtureExpectationMaximization(data);
        Assertions.assertTrue(actualFitter.fit(startModel));

        // Fit the same data using the reference implementation.
        final MultivariateNormalMixtureExpectationMaximization expectedFitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        expectedFitter.fit(
            MultivariateNormalMixtureExpectationMaximization.estimate(data, components));

        // Our log-likelihood is divided by the sample size before comparison,
        // so the reference value is treated as a per-sample average.
        final double actualLogLikelihood = actualFitter.getLogLikelihood() / sampleSize;
        Assertions.assertNotEquals(0, actualLogLikelihood);
        final double expectedLogLikelihood = expectedFitter.getLogLikelihood();
        TestAssertions.assertTest(expectedLogLikelihood, actualLogLikelihood, test);

        final MixtureMultivariateGaussianDistribution fittedModel = actualFitter.getFittedModel();
        Assertions.assertNotNull(fittedModel);
        assertSameMixture(expectedFitter.getFittedModel(), fittedModel, components, test);

        final int iterations = actualFitter.getIterations();
        Assertions.assertNotEquals(0, iterations);
        // Capping the iteration count below what was needed must report non-convergence.
        if (iterations > 2) {
            Assertions.assertFalse(actualFitter.fit(startModel, 2, DEFAULT_CONVERGENCE_CHECKER));
        }
    }
}

/**
 * Asserts that a fitted mixture matches the reference Commons Math mixture component by
 * component: the number of components, and each component's weight, means and covariances.
 *
 * @param expected the reference mixture from Commons Math
 * @param actual the mixture fitted by the implementation under test
 * @param components the expected number of components
 * @param test the closeness predicate used for the numeric comparisons
 */
private static void assertSameMixture(MixtureMultivariateNormalDistribution expected,
    MixtureMultivariateGaussianDistribution actual, int components,
    DoubleDoubleBiPredicate test) {
    final List<Pair<Double, MultivariateNormalDistribution>> expectedComponents =
        expected.getComponents();
    final double[] actualWeights = actual.getWeights();
    final MultivariateGaussianDistribution[] actualDistributions = actual.getDistributions();
    Assertions.assertEquals(components, expectedComponents.size());
    Assertions.assertEquals(components, actualWeights.length);
    Assertions.assertEquals(components, actualDistributions.length);
    for (int k = 0; k < components; k++) {
        final Pair<Double, MultivariateNormalDistribution> reference = expectedComponents.get(k);
        TestAssertions.assertTest(reference.getFirst(), actualWeights[k], test, "weight");
        final MultivariateNormalDistribution referenceDistribution = reference.getSecond();
        final MultivariateGaussianDistribution actualDistribution = actualDistributions[k];
        TestAssertions.assertArrayTest(referenceDistribution.getMeans(),
            actualDistribution.getMeans(), test, "means");
        TestAssertions.assertArrayTest(referenceDistribution.getCovariances().getData(),
            actualDistribution.getCovariances(), test, "covariances");
    }
}
Also used: DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate), MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution), MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution), MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution), MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution), UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider), MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization), Pair (org.apache.commons.math3.util.Pair), SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution): 6; MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution): 5; Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 4; Pair (org.apache.commons.math3.util.Pair): 4; UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 4; LocalList (uk.ac.sussex.gdsc.core.utils.LocalList): 4; Plot (ij.gui.Plot): 3; LUT (ij.process.LUT): 3; Color (java.awt.Color): 3; TDoubleList (gnu.trove.list.TDoubleList): 2; TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList): 2; TFloatArrayList (gnu.trove.list.array.TFloatArrayList): 2; TIntArrayList (gnu.trove.list.array.TIntArrayList): 2; TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap): 2; TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap): 2; TIntHashSet (gnu.trove.set.hash.TIntHashSet): 2; IJ (ij.IJ): 2; ImagePlus (ij.ImagePlus): 2; WindowManager (ij.WindowManager): 2; GenericDialog (ij.gui.GenericDialog): 2