Search in sources :

Example 6 with DoubleDoubleBiPredicate

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method canEstimateInitialMixture.

@SeededTest
void canEstimateInitialMixture(RandomSeed seed) {
    // Test verses the Commons Math estimation
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    // Number of components
    for (int n = 2; n <= 3; n++) {
        final double[] sampleWeights = createWeights(n, rng);
        final double[][] sampleMeans = create(n, 2, rng, -5, 5);
        final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
        final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
        final double[][] data = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
        final MixtureMultivariateGaussianDistribution model1 = MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
        final MixtureMultivariateNormalDistribution model2 = MultivariateNormalMixtureExpectationMaximization.estimate(data, n);
        final List<Pair<Double, MultivariateNormalDistribution>> comp = model2.getComponents();
        final double[] weights = model1.getWeights();
        final MultivariateGaussianDistribution[] distributions = model1.getDistributions();
        Assertions.assertEquals(n, comp.size());
        Assertions.assertEquals(n, weights.length);
        Assertions.assertEquals(n, distributions.length);
        for (int i = 0; i < n; i++) {
            // Must be binary equal for estimated model
            Assertions.assertEquals(comp.get(i).getFirst(), weights[i], "weight");
            final MultivariateNormalDistribution d = comp.get(i).getSecond();
            TestAssertions.assertArrayTest(d.getMeans(), distributions[i].getMeans(), test, "means");
            TestAssertions.assertArrayTest(d.getCovariances().getData(), distributions[i].getCovariances(), test, "covariances");
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 7 with DoubleDoubleBiPredicate

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method canFit.

@SeededTest
void canFit(RandomSeed seed) {
    // Test verses the Commons Math estimation
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    final int sampleSize = 1000;
    // Number of components
    for (int n = 2; n <= 3; n++) {
        final double[] sampleWeights = createWeights(n, rng);
        final double[][] sampleMeans = create(n, 2, rng, -5, 5);
        final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
        final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
        final double[][] data = createData2d(sampleSize, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
        final MixtureMultivariateGaussianDistribution initialModel1 = MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
        final MultivariateGaussianMixtureExpectationMaximization fitter1 = new MultivariateGaussianMixtureExpectationMaximization(data);
        Assertions.assertTrue(fitter1.fit(initialModel1));
        final MultivariateNormalMixtureExpectationMaximization fitter2 = new MultivariateNormalMixtureExpectationMaximization(data);
        fitter2.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, n));
        final double ll1 = fitter1.getLogLikelihood() / sampleSize;
        Assertions.assertNotEquals(0, ll1);
        final double ll2 = fitter2.getLogLikelihood();
        TestAssertions.assertTest(ll2, ll1, test);
        final MixtureMultivariateGaussianDistribution model1 = fitter1.getFittedModel();
        Assertions.assertNotNull(model1);
        final MixtureMultivariateNormalDistribution model2 = fitter2.getFittedModel();
        // Check fitted models are the same
        final List<Pair<Double, MultivariateNormalDistribution>> comp = model2.getComponents();
        final double[] weights = model1.getWeights();
        final MultivariateGaussianDistribution[] distributions = model1.getDistributions();
        Assertions.assertEquals(n, comp.size());
        Assertions.assertEquals(n, weights.length);
        Assertions.assertEquals(n, distributions.length);
        for (int i = 0; i < n; i++) {
            TestAssertions.assertTest(comp.get(i).getFirst(), weights[i], test, "weight");
            final MultivariateNormalDistribution d = comp.get(i).getSecond();
            TestAssertions.assertArrayTest(d.getMeans(), distributions[i].getMeans(), test, "means");
            TestAssertions.assertArrayTest(d.getCovariances().getData(), distributions[i].getCovariances(), test, "covariances");
        }
        final int iterations = fitter1.getIterations();
        Assertions.assertNotEquals(0, iterations);
        // Test without convergence
        if (iterations > 2) {
            Assertions.assertFalse(fitter1.fit(initialModel1, 2, DEFAULT_CONVERGENCE_CHECKER));
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)6 MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution)5 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)5 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)3 MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution)2 MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)2 Pair (org.apache.commons.math3.util.Pair)2 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)2 LocalList (uk.ac.sussex.gdsc.core.utils.LocalList)2 TDoubleList (gnu.trove.list.TDoubleList)1 TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList)1 TFloatArrayList (gnu.trove.list.array.TFloatArrayList)1 TIntArrayList (gnu.trove.list.array.TIntArrayList)1 TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap)1 TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap)1 TIntHashSet (gnu.trove.set.hash.TIntHashSet)1 IJ (ij.IJ)1 ImagePlus (ij.ImagePlus)1 WindowManager (ij.WindowManager)1 GenericDialog (ij.gui.GenericDialog)1