Search in sources:

Example 11 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in the GDSC-SMLM project by aherbert.

Method canFit of the class MultivariateGaussianMixtureExpectationMaximizationTest.

/**
 * Fits 2D Gaussian mixture models with 2 and 3 components and checks that the result matches
 * the Apache Commons Math {@code MultivariateNormalMixtureExpectationMaximization} fit of the
 * same data.
 *
 * <p>For each component count, random mixture data is generated. Both implementations are fit
 * starting from their own {@code estimate(data, n)} initial model. The log-likelihood, component
 * weights, means and covariances are then compared using
 * {@code TestHelper.doublesAreClose(1e-5, 1e-16)}. Finally, when the unconstrained fit needed
 * more than 2 iterations, a fit capped at 2 iterations must report failure to converge.
 *
 * @param seed the seed for the random data generator
 */
@SeededTest
void canFit(RandomSeed seed) {
    // Test versus the Commons Math estimation
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    final int sampleSize = 1000;
    // Number of components
    for (int n = 2; n <= 3; n++) {
        final double[] sampleWeights = createWeights(n, rng);
        final double[][] sampleMeans = create(n, 2, rng, -5, 5);
        final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
        final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
        final double[][] data = createData2d(sampleSize, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
        final MixtureMultivariateGaussianDistribution initialModel1 = MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
        final MultivariateGaussianMixtureExpectationMaximization fitter1 = new MultivariateGaussianMixtureExpectationMaximization(data);
        Assertions.assertTrue(fitter1.fit(initialModel1));
        final MultivariateNormalMixtureExpectationMaximization fitter2 = new MultivariateNormalMixtureExpectationMaximization(data);
        fitter2.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, n));
        // Commons Math reports the log-likelihood averaged over the observations, so fitter1's
        // value is scaled by the number of observations before comparing. Use the actual number
        // of rows the fitters saw rather than the requested sample size.
        final double ll1 = fitter1.getLogLikelihood() / data.length;
        Assertions.assertNotEquals(0, ll1);
        final double ll2 = fitter2.getLogLikelihood();
        TestAssertions.assertTest(ll2, ll1, test);
        final MixtureMultivariateGaussianDistribution model1 = fitter1.getFittedModel();
        Assertions.assertNotNull(model1);
        final MixtureMultivariateNormalDistribution model2 = fitter2.getFittedModel();
        // Check fitted models are the same (components are compared in index order)
        final List<Pair<Double, MultivariateNormalDistribution>> comp = model2.getComponents();
        final double[] weights = model1.getWeights();
        final MultivariateGaussianDistribution[] distributions = model1.getDistributions();
        Assertions.assertEquals(n, comp.size());
        Assertions.assertEquals(n, weights.length);
        Assertions.assertEquals(n, distributions.length);
        for (int i = 0; i < n; i++) {
            TestAssertions.assertTest(comp.get(i).getFirst(), weights[i], test, "weight");
            final MultivariateNormalDistribution d = comp.get(i).getSecond();
            TestAssertions.assertArrayTest(d.getMeans(), distributions[i].getMeans(), test, "means");
            TestAssertions.assertArrayTest(d.getCovariances().getData(), distributions[i].getCovariances(), test, "covariances");
        }
        final int iterations = fitter1.getIterations();
        Assertions.assertNotEquals(0, iterations);
        // Test without convergence: only meaningful when the full fit needed more iterations
        // than the cap of 2 used here.
        if (iterations > 2) {
            Assertions.assertFalse(fitter1.fit(initialModel1, 2, DEFAULT_CONVERGENCE_CHECKER));
        }
    }
}
Also used: DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate), MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution), MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution), MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution), MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution), UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider), MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization), Pair (org.apache.commons.math3.util.Pair), SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 12 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in the GDSC-SMLM project by aherbert.

Method canCreateMixtureMultivariateGaussianDistribution of the class MultivariateGaussianMixtureExpectationMaximizationTest.

/**
 * Creates a two-component mixture from explicit weights, means and covariances and verifies it.
 *
 * <p>Checks that {@code create(weights, means, covariances)}:
 * <ul>
 * <li>normalises the weights to {0.25, 0.75} without modifying the input array;</li>
 * <li>stores the supplied means and covariances;</li>
 * <li>gives the same density as Apache Commons Math {@code MixtureMultivariateNormalDistribution}
 * at the sample points of both components.</li>
 * </ul>
 * Also checks that {@code create(weights, distributions)} keeps the weights array by reference
 * and normalises it in place.
 */
@Test
void canCreateMixtureMultivariateGaussianDistribution() {
    // Unnormalised weights; expected to be normalised to {0.25, 0.75}
    final double[] weights = { 1, 3 };
    final double[][] means = new double[2][];
    final double[][][] covariances = new double[2][][];
    final double[][] data = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
    final double[][] data2 = { { 4, 2 }, { 3.5, -1.5 }, { -3.5, 1.0 } };
    means[0] = getColumnMeans(data);
    covariances[0] = getCovariance(data);
    means[1] = getColumnMeans(data2);
    covariances[1] = getCovariance(data2);
    final MixtureMultivariateGaussianDistribution dist = MixtureMultivariateGaussianDistribution.create(weights, means, covariances);
    Assertions.assertArrayEquals(new double[] { 0.25, 0.75 }, dist.getWeights());
    final MultivariateGaussianDistribution[] distributions = dist.getDistributions();
    Assertions.assertEquals(weights.length, distributions.length);
    for (int i = 0; i < means.length; i++) {
        Assertions.assertArrayEquals(means[i], distributions[i].getMeans());
        Assertions.assertArrayEquals(covariances[i], distributions[i].getCovariances());
    }
    // Test against Apache commons at the sample points of both components so that the
    // high-density region of each component is exercised.
    final MixtureMultivariateNormalDistribution expDist = new MixtureMultivariateNormalDistribution(weights, means, covariances);
    for (final double[][] points : new double[][][] { data, data2 }) {
        for (final double[] x : points) {
            Assertions.assertEquals(expDist.density(x), dist.density(x), 1e-10);
        }
    }
    // The create(weights, means, covariances) call above must not have modified the input weights
    Assertions.assertArrayEquals(new double[] { 1, 3 }, weights);
    // The package private create(weights, distributions) method normalises the weights
    final MixtureMultivariateGaussianDistribution dist2 = MixtureMultivariateGaussianDistribution.create(weights, distributions);
    // Stored by reference
    Assertions.assertArrayEquals(weights, dist2.getWeights());
    // Normalised in-place
    Assertions.assertArrayEquals(new double[] { 0.25, 0.75 }, weights);
}
Also used: MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution), MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution), MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution), SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest), Test (org.junit.jupiter.api.Test)

Aggregations

MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution): 10; MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution): 9; SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest): 6; LocalList (uk.ac.sussex.gdsc.core.utils.LocalList): 4; MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution): 3; Test (org.junit.jupiter.api.Test): 3; Plot (ij.gui.Plot): 2; LUT (ij.process.LUT): 2; TextWindow (ij.text.TextWindow): 2; Color (java.awt.Color): 2; MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 2; Pair (org.apache.commons.math3.util.Pair): 2; UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 2; BufferedTextWindow (uk.ac.sussex.gdsc.core.ij.BufferedTextWindow): 2; DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate): 2; TDoubleList (gnu.trove.list.TDoubleList): 1; TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList): 1; TFloatArrayList (gnu.trove.list.array.TFloatArrayList): 1; TIntArrayList (gnu.trove.list.array.TIntArrayList): 1; TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap): 1