use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximization method fit.
/**
 * Fit a mixture model to the data supplied to the constructor.
 *
 * <p>The quality of the fit depends on the concavity of the data provided to the constructor and
 * the initial mixture provided to this function. If the data has many local optima, multiple runs
 * of the fitting function with different initial mixtures may be required to find the optimal
 * solution. If a SingularMatrixException is encountered, it is possible that another
 * initialisation would work.
 *
 * <p>At most {@code maxIterations} expectation-maximisation iterations are performed. The
 * iteration counter, log-likelihood and fitted model stored on this instance are reset at the
 * start of the call and updated after every iteration.
 *
 * @param initialMixture model containing initial values of weights and multivariate normals
 * @param maxIterations maximum iterations allowed for fit
 * @param convergencePredicate convergence predicate used to test the logLikelihoods between
 *        successive iterations
 * @return true if the convergence predicate was satisfied within {@code maxIterations}
 *         iterations
 * @throws IllegalArgumentException if maxIterations is less than one or if initialMixture mean
 *         vector and data number of columns are not equal
 * @throws SingularMatrixException if any component's covariance matrix is singular during
 *         fitting.
 * @throws NonPositiveDefiniteMatrixException if any component's covariance matrix is not positive
 *         definite during fitting.
 */
public boolean fit(MixtureMultivariateGaussianDistribution initialMixture, int maxIterations,
    DoubleDoubleBiPredicate convergencePredicate) {
  ValidationUtils.checkStrictlyPositive(maxIterations, "maxIterations");
  ValidationUtils.checkNotNull(convergencePredicate, "convergencePredicate");
  ValidationUtils.checkNotNull(initialMixture, "initialMixture");

  final int n = data.length;
  final int k = initialMixture.weights.length;

  // Number of data columns.
  final int numCols = data[0].length;
  final int numMeanColumns = initialMixture.distributions[0].means.length;
  ValidationUtils.checkArgument(numCols == numMeanColumns,
      "Mixture model dimension mismatch with data columns: %d != %d", numCols, numMeanColumns);

  logLikelihood = -Double.MAX_VALUE;
  iterations = 0;

  // Initialize model to fit to initial mixture.
  fittedModel = initialMixture;

  // Run at most maxIterations iterations. The counter is incremented inside the loop so that
  // it always equals the number of iterations actually executed, including when the fit
  // does not converge.
  while (iterations < maxIterations) {
    iterations++;
    final double previousLogLikelihood = logLikelihood;
    double sumLogLikelihood = 0;

    // Weight and distribution of each component
    final double[] weights = fittedModel.weights;
    final MultivariateGaussianDistribution[] mvns = fittedModel.distributions;

    // E-step: compute the data dependent parameters of the expectation function.
    // The fraction of each row's total density attributed to each component
    final double[][] gamma = new double[n][k];
    // Sum of gamma for each component
    final double[] gammaSums = new double[k];
    // Sum of gamma times its row for each component
    final double[][] gammaDataProdSums = new double[k][numCols];
    // Cache for the weight multiplied by the distribution density
    final double[] mvnDensity = new double[k];

    for (int i = 0; i < n; i++) {
      final double[] point = data[i];

      // Compute densities for each component and the row density
      double rowDensity = 0;
      for (int j = 0; j < k; j++) {
        final double d = weights[j] * mvns[j].density(point);
        mvnDensity[j] = d;
        rowDensity += d;
      }
      sumLogLikelihood += Math.log(rowDensity);

      for (int j = 0; j < k; j++) {
        gamma[i][j] = mvnDensity[j] / rowDensity;
        gammaSums[j] += gamma[i][j];
        for (int col = 0; col < numCols; col++) {
          gammaDataProdSums[j][col] += gamma[i][j] * point[col];
        }
      }
    }

    logLikelihood = sumLogLikelihood;

    // M-step: compute the new parameters based on the expectation function.
    final double[] newWeights = new double[k];
    final double[][] newMeans = new double[k][numCols];
    for (int j = 0; j < k; j++) {
      newWeights[j] = gammaSums[j] / n;
      for (int col = 0; col < numCols; col++) {
        newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
      }
    }

    // Compute new covariance matrices.
    // These are symmetric so we compute the triangular half.
    final double[][][] newCovMats = new double[k][numCols][numCols];
    final double[] vec = new double[numCols];
    for (int i = 0; i < n; i++) {
      final double[] point = data[i];
      for (int j = 0; j < k; j++) {
        subtract(point, newMeans[j], vec);
        final double g = gamma[i][j];
        // covariance = vec * vecT
        // covariance[ii][jj] = vec[ii] * vec[jj] * gamma[i][j]
        final double[][] covar = newCovMats[j];
        for (int ii = 0; ii < numCols; ii++) {
          // pre-compute
          final double vig = vec[ii] * g;
          final double[] covari = covar[ii];
          for (int jj = 0; jj <= ii; jj++) {
            covari[jj] += vig * vec[jj];
          }
        }
      }
    }

    // Converting to arrays for use by fitted model
    final MultivariateGaussianDistribution[] distributions =
        new MultivariateGaussianDistribution[k];
    for (int j = 0; j < k; j++) {
      // Make symmetric and normalise by gamma sum
      final double norm = 1.0 / gammaSums[j];
      final double[][] covar = newCovMats[j];
      for (int ii = 0; ii < numCols; ii++) {
        // diagonal
        covar[ii][ii] *= norm;
        // off-diagonal elements: mirror the computed triangle
        for (int jj = 0; jj < ii; jj++) {
          final double tmp = covar[ii][jj] * norm;
          covar[ii][jj] = tmp;
          covar[jj][ii] = tmp;
        }
      }
      distributions[j] = new MultivariateGaussianDistribution(newMeans[j], covar);
    }

    // Update current model
    fittedModel = MixtureMultivariateGaussianDistribution.create(newWeights, distributions);

    // Check convergence
    if (convergencePredicate.test(previousLogLikelihood, logLikelihood)) {
      return true;
    }
  }

  // No convergence
  return false;
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximization method estimate.
/**
 * Helper method to create a multivariate Gaussian mixture model which can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>The function is used to extract a value for ranking all points. These are then split
 * uniformly into the specified number of components. This is equivalent to projection of the
 * points onto a line and partitioning the points along the line uniformly; this cuts the points
 * into sets using hyper-planes defined by the vector of the line. A good ranking metric is a dot
 * product with a random unit vector uniformly sampled from the surface of an n-dimension
 * hypersphere.
 *
 * <pre>{@code
 * double[] vec = ...;
 * ToDoubleFunction<double[]> rankingMetric =
 *   values -> {
 *     double d = 0;
 *     for (int i = 0; i < vec.length; i++) {
 *       d += vec[i] * values[i];
 *     }
 *     return d;
 *   };
 * }</pre>
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * <p>Each component is given the same weight ({@code 1 / numComponents}).
 *
 * @param data data to estimate distribution
 * @param numComponents number of components for estimated mixture
 * @param rankingMetric the function to generate the ranking metric (must not be null)
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows, if {@code numComponents < 2},
 *         if {@code numComponents} is greater than the number of data rows, or if the first data
 *         row has less than 2 columns.
 */
public static MixtureMultivariateGaussianDistribution estimate(double[][] data, int numComponents,
    ToDoubleFunction<double[]> rankingMetric) {
  ValidationUtils.checkArgument(data.length >= 2,
      "Estimation requires at least 2 data points: %d", data.length);
  ValidationUtils.checkArgument(numComponents >= 2,
      "Multivariate Gaussian mixture requires at least 2 components: %d", numComponents);
  ValidationUtils.checkArgument(numComponents <= data.length,
      "Number of components %d greater than data length %d", numComponents, data.length);
  ValidationUtils.checkNotNull(rankingMetric, "rankingMetric");

  final int numRows = data.length;
  final int numCols = data[0].length;
  ValidationUtils.checkArgument(numCols >= 2,
      "Multivariate Gaussian requires at least 2 data columns: %d", numCols);

  // Sort the data by the ranking metric
  final ProjectedData[] sortedData = new ProjectedData[numRows];
  for (int i = 0; i < numRows; i++) {
    sortedData[i] = new ProjectedData(data[i], rankingMetric.applyAsDouble(data[i]));
  }
  Arrays.sort(sortedData, (o1, o2) -> Double.compare(o1.value, o2.value));

  // components of mixture model to be created
  final MultivariateGaussianDistribution[] distributions =
      new MultivariateGaussianDistribution[numComponents];

  // create a component based on data in each bin
  for (int binIndex = 0; binIndex < numComponents; binIndex++) {
    // Bin boundaries are computed using long arithmetic: the product of the bin index and
    // the number of rows can exceed Integer.MAX_VALUE for large data. The quotient is
    // always in [0, numRows] so the narrowing cast is safe.

    // minimum index (inclusive) from sorted data for this bin
    final int minIndex = (int) (((long) binIndex * numRows) / numComponents);

    // maximum index (exclusive) from sorted data for this bin
    final int maxIndex = (int) (((long) (binIndex + 1) * numRows) / numComponents);

    // number of data records that will be in this bin
    final int numBinRows = maxIndex - minIndex;

    // data for this bin
    final double[][] binData = new double[numBinRows][];

    // mean of each column for the data in this bin
    final double[] columnMeans = new double[numCols];

    // populate bin and create component
    for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
      final double[] values = sortedData[i].data;
      binData[iBin] = values;
      for (int j = 0; j < numCols; j++) {
        columnMeans[j] += values[j];
      }
    }

    // Convert the column sums to means
    SimpleArrayUtils.multiply(columnMeans, 1.0 / numBinRows);

    distributions[binIndex] =
        new MultivariateGaussianDistribution(columnMeans, covariance(columnMeans, binData));
  }

  // uniform weight for each bin
  final double[] weights = SimpleArrayUtils.newDoubleArray(numComponents, 1.0 / numComponents);

  return new MixtureMultivariateGaussianDistribution(weights, distributions);
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximization method createMixed.
/**
 * Build a starting multivariate Gaussian mixture model from data that has already been
 * assigned to components. The result can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>Rows that share the same component value are grouped together, and one Gaussian is
 * created per group from the group's column means and covariance. The Gaussians are created
 * in ascending order of the component value.
 *
 * <p>The model is only a starting point: it is not guaranteed to lead the fit to the optimal
 * solution, or to converge at all.
 *
 * <p>The weights for each component will be uniform.
 *
 * @param data data for the distribution
 * @param component the component for each data point
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows, if there is a size mismatch
 *         between the data and components length, or if the number of components is less than 2.
 */
public static MixtureMultivariateGaussianDistribution createMixed(double[][] data,
    int[] component) {
  ValidationUtils.checkArgument(data.length >= 2,
      "Estimation requires at least 2 data points: %d", data.length);
  ValidationUtils.checkArgument(data.length == component.length,
      "Data and component size mismatch: %d != %d", data.length, component.length);

  final int rowCount = data.length;
  final int columnCount = data[0].length;
  ValidationUtils.checkArgument(columnCount >= 2,
      "Multivariate Gaussian requires at least 2 data columns: %d", columnCount);

  // Pair each row with its component and order the pairs by component
  final ClassifiedData[] classified = new ClassifiedData[rowCount];
  for (int row = 0; row < rowCount; row++) {
    classified[row] = new ClassifiedData(data[row], component[row]);
  }
  Arrays.sort(classified, (a, b) -> Integer.compare(a.value, b.value));

  // After sorting, equal first and last values mean there is only one component
  ValidationUtils.checkArgument(classified[0].value != classified[rowCount - 1].value,
      "Mixture model requires at least 2 data components");

  // One Gaussian per run of identical component values
  final LocalList<MultivariateGaussianDistribution> gaussians = new LocalList<>();
  for (int start = 0, end; start < rowCount; start = end) {
    // Advance to the first row that belongs to a different component
    final int label = classified[start].value;
    end = start + 1;
    while (end < rowCount && classified[end].value == label) {
      end++;
    }

    // Collect the rows of this component and accumulate the column sums
    final int groupSize = end - start;
    final double[][] group = new double[groupSize][];
    final double[] means = new double[columnCount];
    for (int offset = 0; offset < groupSize; offset++) {
      final double[] row = classified[start + offset].data;
      group[offset] = row;
      for (int col = 0; col < columnCount; col++) {
        means[col] += row[col];
      }
    }

    // Turn the sums into means
    SimpleArrayUtils.multiply(means, 1.0 / groupSize);

    gaussians.add(new MultivariateGaussianDistribution(means, covariance(means, group)));
  }

  // Every component receives the same weight
  final int total = gaussians.size();
  final double[] uniformWeights = SimpleArrayUtils.newDoubleArray(total, 1.0 / total);

  return new MixtureMultivariateGaussianDistribution(uniformWeights,
      gaussians.toArray(new MultivariateGaussianDistribution[0]));
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method canCreateUnmixedMultivariateGaussianDistribution.
@Test
void canCreateUnmixedMultivariateGaussianDistribution() {
  // Three 2D points used to build both the reference and the observed distribution
  final double[][] points = {{1, 2}, {2.5, 1.5}, {3.5, 1.0}};

  // Reference: distribution built from the independently computed means and covariance
  final MultivariateGaussianDistribution expected =
      MultivariateGaussianDistribution.create(getColumnMeans(points), getCovariance(points));

  // Observed: distribution built by the method under test
  final MultivariateGaussianDistribution actual =
      MultivariateGaussianMixtureExpectationMaximization.createUnmixed(points);

  final DoubleDoubleBiPredicate closeEnough = TestHelper.doublesAreClose(1e-6);
  TestAssertions.assertArrayTest(expected.getMeans(), actual.getMeans(), closeEnough);
  TestAssertions.assertArrayTest(expected.getCovariances(), actual.getCovariances(),
      closeEnough);
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method canEstimateInitialMixture.
@SeededTest
void canEstimateInitialMixture(RandomSeed seed) {
  // Compare the estimated initial mixture against the Commons Math estimation
  final UniformRandomProvider random = RngUtils.create(seed.getSeed());
  final DoubleDoubleBiPredicate closeEnough = TestHelper.doublesAreClose(1e-5, 1e-16);

  // Repeat for 2 and 3 mixture components
  for (int numComponents = 2; numComponents <= 3; numComponents++) {
    // Random 2D mixture parameters and 1000 points sampled from them
    final double[] mixWeights = createWeights(numComponents, random);
    final double[][] mixMeans = create(numComponents, 2, random, -5, 5);
    final double[][] mixStdDevs = create(numComponents, 2, random, 1, 10);
    final double[] mixCorrelations = create(numComponents, random, -0.9, 0.9);
    final double[][] points =
        createData2d(1000, random, mixWeights, mixMeans, mixStdDevs, mixCorrelations);

    // Estimate with the method under test and with the reference implementation
    final MixtureMultivariateGaussianDistribution observed =
        MultivariateGaussianMixtureExpectationMaximization.estimate(points, numComponents);
    final MixtureMultivariateNormalDistribution reference =
        MultivariateNormalMixtureExpectationMaximization.estimate(points, numComponents);

    final List<Pair<Double, MultivariateNormalDistribution>> expectedComponents =
        reference.getComponents();
    final double[] observedWeights = observed.getWeights();
    final MultivariateGaussianDistribution[] observedDistributions =
        observed.getDistributions();

    Assertions.assertEquals(numComponents, expectedComponents.size());
    Assertions.assertEquals(numComponents, observedWeights.length);
    Assertions.assertEquals(numComponents, observedDistributions.length);

    for (int index = 0; index < numComponents; index++) {
      final Pair<Double, MultivariateNormalDistribution> expected =
          expectedComponents.get(index);
      // Must be binary equal for estimated model
      Assertions.assertEquals(expected.getFirst(), observedWeights[index], "weight");

      final MultivariateNormalDistribution expectedDistribution = expected.getSecond();
      final MultivariateGaussianDistribution observedDistribution =
          observedDistributions[index];
      TestAssertions.assertArrayTest(expectedDistribution.getMeans(),
          observedDistribution.getMeans(), closeEnough, "means");
      TestAssertions.assertArrayTest(expectedDistribution.getCovariances().getData(),
          observedDistribution.getCovariances(), closeEnough, "covariances");
    }
  }
}
Aggregations