Search in sources :

Example 6 with MultivariateGaussianDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

Defined in the class MultivariateGaussianMixtureExpectationMaximization, method fit.

/**
 * Fit a mixture model to the data supplied to the constructor.
 *
 * <p>The quality of the fit depends on the concavity of the data provided to the constructor and
 * the initial mixture provided to this function. If the data has many local optima, multiple runs
 * of the fitting function with different initial mixtures may be required to find the optimal
 * solution. If a SingularMatrixException is encountered, it is possible that another
 * initialisation would work.
 *
 * @param initialMixture model containing initial values of weights and multivariate normals
 * @param maxIterations maximum iterations allowed for fit
 * @param convergencePredicate convergence predicate used to test the logLikelihoods between
 *        successive iterations
 * @return true if converged within the threshold
 * @throws IllegalArgumentException if maxIterations is less than one or if initialMixture mean
 *         vector and data number of columns are not equal
 * @throws SingularMatrixException if any component's covariance matrix is singular during
 *         fitting.
 * @throws NonPositiveDefiniteMatrixException if any component's covariance matrix is not positive
 *         definite during fitting.
 */
public boolean fit(MixtureMultivariateGaussianDistribution initialMixture, int maxIterations, DoubleDoubleBiPredicate convergencePredicate) {
    ValidationUtils.checkStrictlyPositive(maxIterations, "maxIterations");
    ValidationUtils.checkNotNull(convergencePredicate, "convergencePredicate");
    ValidationUtils.checkNotNull(initialMixture, "initialMixture");
    final int n = data.length;
    // Number of mixture components.
    final int k = initialMixture.weights.length;
    // Number of data columns.
    final int numCols = data[0].length;
    final int numMeanColumns = initialMixture.distributions[0].means.length;
    ValidationUtils.checkArgument(numCols == numMeanColumns, "Mixture model dimension mismatch with data columns: %d != %d", numCols, numMeanColumns);
    logLikelihood = -Double.MAX_VALUE;
    iterations = 0;
    // Initialize model to fit to initial mixture.
    fittedModel = initialMixture;
    // Note: the post-increment condition allows the body to run at most maxIterations + 1 times
    // (pre-increment values 0..maxIterations) unless the predicate reports convergence earlier.
    while (iterations++ <= maxIterations) {
        final double previousLogLikelihood = logLikelihood;
        double sumLogLikelihood = 0;
        // Weight and distribution of each component
        final double[] weights = fittedModel.weights;
        final MultivariateGaussianDistribution[] mvns = fittedModel.distributions;
        // E-step: compute the data dependent parameters of the expectation function.
        // The percentage of row's total density between a row and a component
        final double[][] gamma = new double[n][k];
        // Sum of gamma for each component
        final double[] gammaSums = new double[k];
        // Sum of gamma times its row for each component
        final double[][] gammaDataProdSums = new double[k][numCols];
        // Cache for the weight multiplied by the distribution density
        final double[] mvnDensity = new double[k];
        for (int i = 0; i < n; i++) {
            final double[] point = data[i];
            // Compute densities for each component and the row density
            double rowDensity = 0;
            for (int j = 0; j < k; j++) {
                final double d = weights[j] * mvns[j].density(point);
                mvnDensity[j] = d;
                rowDensity += d;
            }
            // Log-likelihood of the current model is the sum of the log of each row density.
            sumLogLikelihood += Math.log(rowDensity);
            for (int j = 0; j < k; j++) {
                // Responsibility of component j for row i.
                gamma[i][j] = mvnDensity[j] / rowDensity;
                gammaSums[j] += gamma[i][j];
                for (int col = 0; col < numCols; col++) {
                    gammaDataProdSums[j][col] += gamma[i][j] * point[col];
                }
            }
        }
        logLikelihood = sumLogLikelihood;
        // M-step: compute the new parameters based on the expectation function.
        // New weight is the mean responsibility; new mean is the responsibility-weighted mean.
        final double[] newWeights = new double[k];
        final double[][] newMeans = new double[k][numCols];
        for (int j = 0; j < k; j++) {
            newWeights[j] = gammaSums[j] / n;
            for (int col = 0; col < numCols; col++) {
                newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
            }
        }
        // Compute new covariance matrices.
        // These are symmetric so we compute the triangular half.
        final double[][][] newCovMats = new double[k][numCols][numCols];
        final double[] vec = new double[numCols];
        for (int i = 0; i < n; i++) {
            final double[] point = data[i];
            for (int j = 0; j < k; j++) {
                // vec = point - mean (reused scratch buffer)
                subtract(point, newMeans[j], vec);
                final double g = gamma[i][j];
                // covariance = vec * vecT
                // covariance[ii][jj] = vec[ii] * vec[jj] * gamma[i][j]
                final double[][] covar = newCovMats[j];
                for (int ii = 0; ii < numCols; ii++) {
                    // pre-compute
                    final double vig = vec[ii] * g;
                    final double[] covari = covar[ii];
                    // Lower triangle only; mirrored after the data loop.
                    for (int jj = 0; jj <= ii; jj++) {
                        covari[jj] += vig * vec[jj];
                    }
                }
            }
        }
        // Converting to arrays for use by fitted model
        final MultivariateGaussianDistribution[] distributions = new MultivariateGaussianDistribution[k];
        for (int j = 0; j < k; j++) {
            // Make symmetric and normalise by gamma sum
            final double norm = 1.0 / gammaSums[j];
            final double[][] covar = newCovMats[j];
            for (int ii = 0; ii < numCols; ii++) {
                // diagonal
                covar[ii][ii] *= norm;
                // elements
                for (int jj = 0; jj < ii; jj++) {
                    final double tmp = covar[ii][jj] * norm;
                    covar[ii][jj] = tmp;
                    covar[jj][ii] = tmp;
                }
            }
            distributions[j] = new MultivariateGaussianDistribution(newMeans[j], covar);
        }
        // Update current model
        fittedModel = MixtureMultivariateGaussianDistribution.create(newWeights, distributions);
        // Check convergence
        if (convergencePredicate.test(previousLogLikelihood, logLikelihood)) {
            return true;
        }
    }
    // No convergence
    return false;
}
Also used : MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)

Example 7 with MultivariateGaussianDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

Defined in the class MultivariateGaussianMixtureExpectationMaximization, method estimate.

/**
 * Helper method to create a multivariate Gaussian mixture model which can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>Each point is reduced to a single ranking value using the supplied function. The ranked
 * points are then divided uniformly into the requested number of components. This is equivalent
 * to projecting the points onto a line and cutting the line into equal partitions using
 * hyper-planes perpendicular to it. A good ranking metric is a dot product with a random unit
 * vector uniformly sampled from the surface of an n-dimension hypersphere.
 *
 * <pre>
 * double[] vec = ...;
 * ToDoubleFunction&lt;double[]&gt; rankingMetric =
 *   values -&gt; {
 *     double d = 0;
 *     for (int i = 0; i &lt; vec.length; i++) {
 *       d += vec[i] * values[i];
 *     }
 *     return d;
 *   };
 * </pre>
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * @param data data to estimate distribution
 * @param numComponents number of components for estimated mixture
 * @param rankingMetric the function to generate the ranking metric
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows or columns, if
 *         {@code numComponents < 2}, or if {@code numComponents} is greater than the number of
 *         data rows.
 */
public static MixtureMultivariateGaussianDistribution estimate(double[][] data, int numComponents, ToDoubleFunction<double[]> rankingMetric) {
    ValidationUtils.checkArgument(data.length >= 2, "Estimation requires at least 2 data points: %d", data.length);
    ValidationUtils.checkArgument(numComponents >= 2, "Multivariate Gaussian mixture requires at least 2 components: %d", numComponents);
    ValidationUtils.checkArgument(numComponents <= data.length, "Number of components %d greater than data length %d", numComponents, data.length);
    ValidationUtils.checkNotNull(rankingMetric, "rankingMetric");
    final int rows = data.length;
    final int dims = data[0].length;
    ValidationUtils.checkArgument(dims >= 2, "Multivariate Gaussian requires at least 2 data columns: %d", dims);
    // Rank every point and order the data by the ranking value
    final ProjectedData[] ranked = new ProjectedData[rows];
    for (int i = 0; i < rows; i++) {
        final double[] point = data[i];
        ranked[i] = new ProjectedData(point, rankingMetric.applyAsDouble(point));
    }
    Arrays.sort(ranked, (a, b) -> Double.compare(a.value, b.value));
    // One Gaussian component per uniform partition of the ranked data
    final MultivariateGaussianDistribution[] components = new MultivariateGaussianDistribution[numComponents];
    for (int c = 0; c < numComponents; c++) {
        // Partition [lo, hi) of the sorted data assigned to component c
        final int lo = (c * rows) / numComponents;
        final int hi = ((c + 1) * rows) / numComponents;
        final int size = hi - lo;
        // Collect the partition and accumulate per-column sums
        final double[][] subset = new double[size][];
        final double[] means = new double[dims];
        for (int i = 0; i < size; i++) {
            final double[] values = ranked[lo + i].data;
            subset[i] = values;
            for (int j = 0; j < dims; j++) {
                means[j] += values[j];
            }
        }
        // Convert column sums to column means
        SimpleArrayUtils.multiply(means, 1.0 / size);
        components[c] = new MultivariateGaussianDistribution(means, covariance(means, subset));
    }
    // uniform weight for each component
    final double[] weights = SimpleArrayUtils.newDoubleArray(numComponents, 1.0 / numComponents);
    return new MixtureMultivariateGaussianDistribution(weights, components);
}
Also used : MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)

Example 8 with MultivariateGaussianDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

Defined in the class MultivariateGaussianMixtureExpectationMaximization, method createMixed.

/**
 * Helper method to create a multivariate Gaussian mixture model which can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * <p>The weights for each component will be uniform.
 *
 * @param data data for the distribution
 * @param component the component for each data point
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows or columns, if there is a size
 *         mismatch between the data and components length, or if the number of components is
 *         less than 2.
 */
public static MixtureMultivariateGaussianDistribution createMixed(double[][] data, int[] component) {
    ValidationUtils.checkArgument(data.length >= 2, "Estimation requires at least 2 data points: %d", data.length);
    ValidationUtils.checkArgument(data.length == component.length, "Data and component size mismatch: %d != %d", data.length, component.length);
    final int rows = data.length;
    final int dims = data[0].length;
    ValidationUtils.checkArgument(dims >= 2, "Multivariate Gaussian requires at least 2 data columns: %d", dims);
    // Pair each point with its component label and order by label
    final ClassifiedData[] ordered = new ClassifiedData[rows];
    for (int i = 0; i < rows; i++) {
        ordered[i] = new ClassifiedData(data[i], component[i]);
    }
    Arrays.sort(ordered, (a, b) -> Integer.compare(a.value, b.value));
    // After sorting, identical first and last labels imply a single component
    ValidationUtils.checkArgument(ordered[0].value != ordered[ordered.length - 1].value, "Mixture model requires at least 2 data components");
    // Build one Gaussian per contiguous run of equal labels
    final LocalList<MultivariateGaussianDistribution> components = new LocalList<>();
    int start = 0;
    while (start < ordered.length) {
        // Advance to the end (exclusive) of the run with this label
        final int label = ordered[start].value;
        int end = start + 1;
        while (end < ordered.length && ordered[end].value == label) {
            end++;
        }
        final int size = end - start;
        // Collect the run and accumulate per-column sums
        final double[][] subset = new double[size][];
        final double[] means = new double[dims];
        for (int i = 0; i < size; i++) {
            final double[] values = ordered[start + i].data;
            subset[i] = values;
            for (int j = 0; j < dims; j++) {
                means[j] += values[j];
            }
        }
        // Convert column sums to column means
        SimpleArrayUtils.multiply(means, 1.0 / size);
        components.add(new MultivariateGaussianDistribution(means, covariance(means, subset)));
        start = end;
    }
    // uniform weight for each component
    final int numComponents = components.size();
    final double[] weights = SimpleArrayUtils.newDoubleArray(numComponents, 1.0 / numComponents);
    return new MixtureMultivariateGaussianDistribution(weights, components.toArray(new MultivariateGaussianDistribution[0]));
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)

Example 9 with MultivariateGaussianDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

Defined in the class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateUnmixedMultivariateGaussianDistribution.

@Test
void canCreateUnmixedMultivariateGaussianDistribution() {
    // A single unmixed component should reproduce the sample mean and covariance
    final double[][] data = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
    final MultivariateGaussianDistribution observed = MultivariateGaussianMixtureExpectationMaximization.createUnmixed(data);
    final MultivariateGaussianDistribution expected = MultivariateGaussianDistribution.create(getColumnMeans(data), getCovariance(data));
    final DoubleDoubleBiPredicate tolerance = TestHelper.doublesAreClose(1e-6);
    TestAssertions.assertArrayTest(expected.getMeans(), observed.getMeans(), tolerance);
    TestAssertions.assertArrayTest(expected.getCovariances(), observed.getCovariances(), tolerance);
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Example 10 with MultivariateGaussianDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

Defined in the class MultivariateGaussianMixtureExpectationMaximizationTest, method canEstimateInitialMixture.

@SeededTest
void canEstimateInitialMixture(RandomSeed seed) {
    // Test versus the Commons Math estimation
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    // Number of components
    for (int numComponents = 2; numComponents <= 3; numComponents++) {
        // Sample a random 2D mixture and draw data from it
        final double[] sampleWeights = createWeights(numComponents, rng);
        final double[][] sampleMeans = create(numComponents, 2, rng, -5, 5);
        final double[][] sampleStdDevs = create(numComponents, 2, rng, 1, 10);
        final double[] sampleCorrelations = create(numComponents, rng, -0.9, 0.9);
        final double[][] data = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
        // Estimate with both implementations
        final MixtureMultivariateGaussianDistribution actual = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents);
        final MixtureMultivariateNormalDistribution reference = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
        final List<Pair<Double, MultivariateNormalDistribution>> referenceComponents = reference.getComponents();
        final double[] weights = actual.getWeights();
        final MultivariateGaussianDistribution[] distributions = actual.getDistributions();
        Assertions.assertEquals(numComponents, referenceComponents.size());
        Assertions.assertEquals(numComponents, weights.length);
        Assertions.assertEquals(numComponents, distributions.length);
        for (int i = 0; i < numComponents; i++) {
            // Must be binary equal for estimated model
            Assertions.assertEquals(referenceComponents.get(i).getFirst(), weights[i], "weight");
            final MultivariateNormalDistribution dist = referenceComponents.get(i).getSecond();
            TestAssertions.assertArrayTest(dist.getMeans(), distributions[i].getMeans(), test, "means");
            TestAssertions.assertArrayTest(dist.getCovariances().getData(), distributions[i].getCovariances(), test, "covariances");
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)12 MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution)9 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)7 MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution)4 Test (org.junit.jupiter.api.Test)4 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)4 MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)3 Pair (org.apache.commons.math3.util.Pair)3 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)3 LocalList (uk.ac.sussex.gdsc.core.utils.LocalList)3 TextWindow (ij.text.TextWindow)2 TDoubleList (gnu.trove.list.TDoubleList)1 TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList)1 TFloatArrayList (gnu.trove.list.array.TFloatArrayList)1 TIntArrayList (gnu.trove.list.array.TIntArrayList)1 TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap)1 TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap)1 TIntHashSet (gnu.trove.set.hash.TIntHashSet)1 IJ (ij.IJ)1 ImagePlus (ij.ImagePlus)1