Search in sources :

Example 16 with NormalizedGaussianSampler

use of org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method createDataNd.

/**
 * Creates the data from a mixture of n ND Gaussian distributions. The length of the weights array
 * (and all other outer arrays) is the number of mixture components. The length of each nested
 * means array (and each nested std.dev. array) is the number of dimensions for the Gaussian.
 * Each nested correlations array has one entry per dimension after the first.
 *
 * <p>The number of samples drawn from each component is {@code round(weight * sampleSize)},
 * adjusted so that the total is exactly {@code sampleSize}. Samples are stored grouped by
 * component, in component order.
 *
 * <p>Dimension {@code k > 0} is created as {@code r * X + sqrt(1 - r^2) * Y} where {@code X} is
 * the deviate for the first dimension, {@code Y} is the deviate for dimension {@code k} and
 * {@code r = correlations[i][k - 1]}. This has correlation {@code r} with the first dimension and
 * unit variance, so the requested std.dev. is respected after scaling. Correlations are expected
 * to be in the range {@code [-1, 1]}; values outside this range produce NaN samples.
 *
 * @param sampleSize the sample size
 * @param rng the random generator
 * @param weights the weights for each component
 * @param means the means for the dimensions (per component)
 * @param stdDevs the std devs for the dimensions (per component)
 * @param correlations the correlations between the first dimension and the remaining dimensions
 *        (per component)
 * @return the samples; {@code sampleSize} rows each of length equal to the number of dimensions
 */
private static double[][] createDataNd(int sampleSize, UniformRandomProvider rng, double[] weights, double[][] means, double[][] stdDevs, double[][] correlations) {
    // Directly sample Gaussian distributions
    final NormalizedGaussianSampler sampler = SamplerUtils.createNormalizedGaussianSampler(rng);
    final double[][] data = new double[sampleSize][];
    int count = 0;
    final int dimensions = means[0].length;
    // Allocate the samples to the components in proportion to the weights
    final int[] nSamples = new int[weights.length];
    for (int i = 0; i < weights.length; i++) {
        nSamples[i] = (int) Math.round(weights[i] * sampleSize);
    }
    // Rounding may leave the total different from the sample size.
    // Level the counts: take from the largest component, or give to the smallest.
    while (MathUtils.sum(nSamples) > sampleSize) {
        nSamples[SimpleArrayUtils.findMaxIndex(nSamples)]--;
    }
    while (MathUtils.sum(nSamples) < sampleSize) {
        nSamples[SimpleArrayUtils.findMinIndex(nSamples)]++;
    }
    for (int i = 0; i < weights.length; i++) {
        // Sample from the i-th ND Gaussian distribution
        final int n = nSamples[i];
        final double[][] samples = new double[n][dimensions];
        // Independent deviates for every dimension (assumed mean 0, std.dev. 1 as per the
        // NormalizedGaussianSampler contract)
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < dimensions; k++) {
                samples[j][k] = sampler.sample();
            }
        }
        // Correlate each remaining dimension with the first dimension.
        // https://www.uvm.edu/~statdhtx/StatPages/More_Stuff/CorrGen.html
        // The reference uses Z = a*X + Y with a = r / sqrt(1 - r^2), which has variance
        // 1 / (1 - r^2). Multiplying through by sqrt(1 - r^2) keeps the correlation r but
        // restores unit variance so that the std.dev. applied below is the actual std.dev.
        // of the dimension. It also avoids the division by zero when |r| == 1.
        for (int k = 1; k < dimensions; k++) {
            final double r = correlations[i][k - 1];
            final double s = Math.sqrt(1 - r * r);
            for (int j = 0; j < n; j++) {
                // Z = r*X + sqrt(1 - r^2)*Y
                samples[j][k] = r * samples[j][0] + s * samples[j][k];
            }
        }
        // Apply mean and std.dev.
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < dimensions; k++) {
                samples[j][k] = samples[j][k] * stdDevs[i][k] + means[i][k];
            }
        }
        // Fill the data (copies the row references for this component)
        System.arraycopy(samples, 0, data, count, n);
        count += n;
    }
    return data;
}
Also used : NormalizedGaussianSampler(org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler)

Aggregations

NormalizedGaussianSampler (org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler)16 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)8 ArrayList (java.util.ArrayList)3 SharedStateContinuousSampler (org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler)3 Test (org.junit.jupiter.api.Test)3 StoredDataStatistics (uk.ac.sussex.gdsc.core.utils.StoredDataStatistics)3 MarsagliaTsangGammaSampler (uk.ac.sussex.gdsc.core.utils.rng.MarsagliaTsangGammaSampler)3 CalibrationWriter (uk.ac.sussex.gdsc.smlm.data.config.CalibrationWriter)3 DmttConfiguration (uk.ac.sussex.gdsc.smlm.results.DynamicMultipleTargetTracing.DmttConfiguration)3 MemoryPeakResults (uk.ac.sussex.gdsc.smlm.results.MemoryPeakResults)3 Statistics (uk.ac.sussex.gdsc.core.utils.Statistics)2 StoredData (uk.ac.sussex.gdsc.core.utils.StoredData)2 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)2 TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList)1 TIntHashSet (gnu.trove.set.hash.TIntHashSet)1 IJ (ij.IJ)1 ImagePlus (ij.ImagePlus)1 Plot (ij.gui.Plot)1 Calibration (ij.measure.Calibration)1 PlugIn (ij.plugin.PlugIn)1