Search in sources :

Example 6 with MultivariateNormalDistribution

Use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in the project GDSC-SMLM by aherbert.

Taken from the canFit method of the class MultivariateGaussianMixtureExpectationMaximizationTest.

/**
 * Cross-checks the GDSC Gaussian mixture expectation-maximization fitter against the
 * Apache Commons Math {@code MultivariateNormalMixtureExpectationMaximization} on random
 * 2D mixture data with 2 and then 3 components.
 *
 * <p>Both fitters start from an initial estimate computed on the same data. The fitted
 * log-likelihood (GDSC value divided by the sample size), the component weights, the
 * means and the covariances must agree under the configured closeness predicate. When
 * the GDSC fit took more than 2 iterations, re-fitting with a 2-iteration limit must
 * return {@code false}.
 *
 * @param seed the seed used to create the random source for the test data
 */
@SeededTest
void canFit(RandomSeed seed) {
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    final int sampleSize = 1000;
    for (int components = 2; components <= 3; components++) {
        // Random mixture parameters; the draw order from rng fixes the generated data.
        final double[] mixWeights = createWeights(components, rng);
        final double[][] mixMeans = create(components, 2, rng, -5, 5);
        final double[][] mixStdDevs = create(components, 2, rng, 1, 10);
        final double[] mixCorrelations = create(components, rng, -0.9, 0.9);
        final double[][] data =
            createData2d(sampleSize, rng, mixWeights, mixMeans, mixStdDevs, mixCorrelations);

        // Implementation under test (GDSC).
        final MixtureMultivariateGaussianDistribution start =
            MultivariateGaussianMixtureExpectationMaximization.estimate(data, components);
        final MultivariateGaussianMixtureExpectationMaximization actualFitter =
            new MultivariateGaussianMixtureExpectationMaximization(data);
        Assertions.assertTrue(actualFitter.fit(start));

        // Reference implementation (Commons Math); its fit method returns nothing to check.
        final MultivariateNormalMixtureExpectationMaximization expectedFitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        expectedFitter.fit(
            MultivariateNormalMixtureExpectationMaximization.estimate(data, components));

        // Normalise the GDSC log-likelihood by the sample size before comparing it with
        // the Commons Math value.
        final double actualLogLikelihood = actualFitter.getLogLikelihood() / sampleSize;
        Assertions.assertNotEquals(0, actualLogLikelihood);
        TestAssertions.assertTest(expectedFitter.getLogLikelihood(), actualLogLikelihood, test);

        final MixtureMultivariateGaussianDistribution actualModel = actualFitter.getFittedModel();
        Assertions.assertNotNull(actualModel);
        final MixtureMultivariateNormalDistribution expectedModel = expectedFitter.getFittedModel();

        // Compare the fitted mixtures component by component, in index order.
        final List<Pair<Double, MultivariateNormalDistribution>> expectedComponents =
            expectedModel.getComponents();
        final double[] actualWeights = actualModel.getWeights();
        final MultivariateGaussianDistribution[] actualDistributions =
            actualModel.getDistributions();
        Assertions.assertEquals(components, expectedComponents.size());
        Assertions.assertEquals(components, actualWeights.length);
        Assertions.assertEquals(components, actualDistributions.length);
        for (int k = 0; k < components; k++) {
            final Pair<Double, MultivariateNormalDistribution> expected = expectedComponents.get(k);
            TestAssertions.assertTest(expected.getFirst(), actualWeights[k], test, "weight");
            final MultivariateNormalDistribution expectedDistribution = expected.getSecond();
            final MultivariateGaussianDistribution actualDistribution = actualDistributions[k];
            TestAssertions.assertArrayTest(expectedDistribution.getMeans(),
                actualDistribution.getMeans(), test, "means");
            TestAssertions.assertArrayTest(expectedDistribution.getCovariances().getData(),
                actualDistribution.getCovariances(), test, "covariances");
        }

        final int iterations = actualFitter.getIterations();
        Assertions.assertNotEquals(0, iterations);
        // Re-fitting from the same start with a 2-iteration limit must return false
        // whenever the unrestricted fit needed more than 2 iterations.
        if (iterations > 2) {
            Assertions.assertFalse(actualFitter.fit(start, 2, DEFAULT_CONVERGENCE_CHECKER));
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)6 MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution)4 Pair (org.apache.commons.math3.util.Pair)3 MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution)3 MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)3 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)3 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)2 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)2 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)1 ArrayList (java.util.ArrayList)1 MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization)1 DenseVector (org.apache.mahout.math.DenseVector)1 Vector (org.apache.mahout.math.Vector)1 Test (org.junit.jupiter.api.Test)1 RandomGeneratorAdapter (uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter)1