use of org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method createDataNd.
/**
 * Creates the data from a mixture of n ND Gaussian distributions. The length of the weights array
 * (and all other outer arrays) is the number of mixture components. The length of each nested
 * means array (and each nested std.dev. array) is the number of dimensions for the Gaussian; each
 * nested correlations array is used for dimensions 1 and above, so it must hold at least
 * (dimensions - 1) values.
 *
 * <p>Each component contributes approximately {@code weights[i] * sampleSize} samples. The
 * rounded counts are levelled so that exactly {@code sampleSize} samples are returned, ordered by
 * component.
 *
 * <p>Dimension {@code k >= 1} is correlated with dimension 0 using
 * {@code Z = r * X + sqrt(1 - r * r) * Y} where X and Y are independent standard normal deviates.
 * This has correlation {@code r} with X and unit variance, so the requested standard deviations
 * are preserved after scaling and {@code r = +/-1} is supported. As a consequence any two
 * dimensions {@code j, k >= 1} have correlation {@code r_j * r_k}.
 *
 * @param sampleSize the sample size
 * @param rng the random generator
 * @param weights the weights for each component
 * @param means the means for the dimensions of each component
 * @param stdDevs the std devs for the dimensions of each component
 * @param correlations the correlations between the first dimension and the remaining dimensions
 *        of each component (values are expected to be in the range [-1, 1])
 * @return the samples: {@code sampleSize} rows, each of length equal to the number of dimensions
 */
private static double[][] createDataNd(int sampleSize, UniformRandomProvider rng, double[] weights, double[][] means, double[][] stdDevs, double[][] correlations) {
  // Directly sample Gaussian distributions
  final NormalizedGaussianSampler sampler = SamplerUtils.createNormalizedGaussianSampler(rng);
  final double[][] data = new double[sampleSize][];
  int count = 0;
  final int dimensions = means[0].length;

  // Initial estimate of the number of samples per component
  final int[] nSamples = new int[weights.length];
  for (int i = 0; i < weights.length; i++) {
    nSamples[i] = (int) Math.round(weights[i] * sampleSize);
  }
  // Ensure we have the correct number of samples by leveling the counts:
  // rounding may leave the total above or below the requested sample size.
  while (MathUtils.sum(nSamples) > sampleSize) {
    nSamples[SimpleArrayUtils.findMaxIndex(nSamples)]--;
  }
  while (MathUtils.sum(nSamples) < sampleSize) {
    nSamples[SimpleArrayUtils.findMinIndex(nSamples)]++;
  }

  // Coefficients for the correlated deviate Z = cx * X + cy * Y (index 0 is unused)
  final double[] cx = new double[dimensions];
  final double[] cy = new double[dimensions];

  for (int i = 0; i < weights.length; i++) {
    // Sample from the i-th ND Gaussian distribution
    final int n = nSamples[i];

    // Normalised form of https://www.uvm.edu/~statdhtx/StatPages/More_Stuff/CorrGen.html
    // The original Z = a*X + Y with a = r / sqrt(1 - r^2) has variance 1 / (1 - r^2).
    // Dividing by sqrt(1 + a^2) gives Z = r*X + sqrt(1 - r^2)*Y with unit variance.
    for (int k = 1; k < dimensions; k++) {
      final double r = correlations[i][k - 1];
      cx[k] = r;
      cy[k] = Math.sqrt(1 - r * r);
    }

    final double[] mean = means[i];
    final double[] sd = stdDevs[i];

    for (int j = 0; j < n; j++) {
      // Independent standard normal deviates
      final double[] row = new double[dimensions];
      for (int k = 0; k < dimensions; k++) {
        row[k] = sampler.sample();
      }
      // Apply correlation with the first dimension (which is still unscaled here)
      for (int k = 1; k < dimensions; k++) {
        row[k] = cx[k] * row[0] + cy[k] * row[k];
      }
      // Apply mean and std.dev.
      for (int k = 0; k < dimensions; k++) {
        row[k] = row[k] * sd[k] + mean[k];
      }
      // Fill the data
      data[count++] = row;
    }
  }
  return data;
}
Aggregations