use of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method testExpectationMaximizationSpeed.
/**
* Test the speed of implementations of the expectation maximization algorithm with a mixture of n
* ND Gaussian distributions.
*
* @param seed the seed
*/
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeed(RandomSeed seed) {
Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
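// Relative convergence check (1e-6) on successive log-likelihood values, used by the third timing task below.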
final MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate relChecker = TestHelper.doublesAreClose(1e-6)::test;
// Create data
final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
for (int n = 2; n <= 3; n++) {
for (int dim = 2; dim <= 4; dim++) {
final double[][][] data = new double[10][][];
final int nCorrelations = dim - 1;
for (int i = 0; i < data.length; i++) {
final double[] sampleWeights = createWeights(n, rng);
final double[][] sampleMeans = create(n, dim, rng, -5, 5);
final double[][] sampleStdDevs = create(n, dim, rng, 1, 10);
final double[][] sampleCorrelations = IntStream.range(0, n).mapToObj(component -> create(nCorrelations, rng, -0.9, 0.9)).toArray(double[][]::new);
data[i] = createDataNd(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
}
final int numComponents = n;
// Time initial estimation and fitting
final TimingService ts = new TimingService();
ts.execute(new FittingSpeedTask("Commons n=" + n + " " + dim + "D", data) {
@Override
Object run(double[][] data) {
final MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data);
fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
return fitter.getLogLikelihood();
}
});
ts.execute(new FittingSpeedTask("GDSC n=" + n + " " + dim + "D", data) {
@Override
Object run(double[][] data) {
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
return fitter.getLogLikelihood();
}
});
ts.execute(new FittingSpeedTask("GDSC rel 1e-6 n=" + n + " " + dim + "D", data) {
@Override
Object run(double[][] data) {
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents), 1000, relChecker);
return fitter.getLogLikelihood();
}
});
if (logger.isLoggable(Level.INFO)) {
logger.info(ts.getReport());
}
// Check the GDSC fit (with the default convergence checker) is more than twice as fast as the Commons Math fit
Assertions.assertTrue(ts.get(-2).getMean() < ts.get(-3).getMean() / 2);
}
}
}
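Each timing task above follows the same estimate-then-fit pattern. As a minimal, self-contained sketch of that pattern against the plain Commons Math 3 API (the data array and class name are made up for illustration; the calls themselves are the same ones used by the Commons timing task):

import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization;

public class CommonsFitSketch {
  public static void main(String[] args) {
    // Stand-in for one generated sample set: rows are observations, columns are dimensions.
    final double[][] data = {
        {-4.2, 1.1}, {-3.9, 0.8}, {-4.5, 1.3}, {-4.1, 0.9},
        {5.0, 6.2}, {4.8, 5.9}, {5.3, 6.4}, {5.1, 6.1},
    };
    final int numComponents = 2;
    // Seed the EM algorithm with a simple initial estimate of the weights, means and covariances.
    final MixtureMultivariateNormalDistribution initial =
        MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
    // Run EM to convergence using the default maximum iterations and threshold.
    final MultivariateNormalMixtureExpectationMaximization fitter =
        new MultivariateNormalMixtureExpectationMaximization(data);
    fitter.fit(initial);
    System.out.println("log-likelihood = " + fitter.getLogLikelihood());
  }
}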
use of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method testExpectationMaximizationSpeedWithDifferentNumberOfComponents.
/**
* Test the speed of implementations of the expectation maximization algorithm with a mixture of n
* 2D Gaussian distributions.
*
* @param seed the seed
*/
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeedWithDifferentNumberOfComponents(RandomSeed seed) {
Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
// Create data
final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
for (int n = 2; n <= 4; n++) {
final double[][][] data = new double[10][][];
for (int i = 0; i < data.length; i++) {
final double[] sampleWeights = createWeights(n, rng);
final double[][] sampleMeans = create(n, 2, rng, -5, 5);
final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
data[i] = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
}
final int numComponents = n;
// Time initial estimation and fitting
final TimingService ts = new TimingService();
ts.execute(new FittingSpeedTask("Commons n=" + n + " 2D", data) {
@Override
Object run(double[][] data) {
final MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data);
fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
return fitter.getLogLikelihood();
}
});
ts.execute(new FittingSpeedTask("GDSC n=" + n + " 2D", data) {
@Override
Object run(double[][] data) {
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
return fitter.getLogLikelihood();
}
});
if (logger.isLoggable(Level.INFO)) {
logger.info(ts.getReport());
}
// Check the GDSC fit is more than twice as fast as the Commons Math fit
Assertions.assertTrue(ts.get(-1).getMean() < ts.get(-2).getMean() / 2);
}
}
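The sample data for these speed tests come from project-specific helpers (createWeights, create, createData2d, createDataNd) whose bodies are not shown here. As a rough, hypothetical illustration of what such generation involves, and not the project's actual helper, the following plain-Java sketch draws correlated 2D samples from a Gaussian mixture using the Cholesky factor of each component's covariance:

import java.util.Random;

public class MixtureSampleSketch {
  // Draw size samples from a 2D Gaussian mixture. Component i has weight weights[i],
  // mean means[i], per-dimension standard deviations stdDevs[i] and correlation
  // correlations[i] between the two dimensions.
  static double[][] sample2d(int size, Random rng, double[] weights,
      double[][] means, double[][] stdDevs, double[] correlations) {
    final double[][] out = new double[size][2];
    for (int s = 0; s < size; s++) {
      // Pick a component according to the weights.
      int c = 0;
      final double u = rng.nextDouble();
      double cumulative = weights[0];
      while (u > cumulative && c < weights.length - 1) {
        cumulative += weights[++c];
      }
      // Correlated pair via the Cholesky factor of the 2x2 covariance:
      //   x = mx + sx * z1
      //   y = my + sy * (rho * z1 + sqrt(1 - rho^2) * z2)
      final double z1 = rng.nextGaussian();
      final double z2 = rng.nextGaussian();
      final double rho = correlations[c];
      out[s][0] = means[c][0] + stdDevs[c][0] * z1;
      out[s][1] = means[c][1] + stdDevs[c][1] * (rho * z1 + Math.sqrt(1 - rho * rho) * z2);
    }
    return out;
  }
}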
use of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization in project gatk-protected by broadinstitute.
the class CoverageDropoutDetector method retrieveGaussianMixtureModelForFilteredTargets.
/** <p>Produces a Gaussian mixture model based on the difference between targets and segment means.</p>
* <p>Filters targets to those in segments that contain more than minProportion of all targets.</p>
* <p>Returns a result with a {@code null} model if no targets pass filtering. Please note that in
* these cases, the rest of this class uses this to assume that a GMM is not a good model.</p>
*
* @param segments -- segments with segment mean in log2 copy ratio space
* @param targets -- targets with a log2 copy ratio estimate
* @param minProportion -- minimum proportion of all targets that a given segment must have in order to be used
* in the evaluation
* @param numComponents -- number of components to use in the GMM. Usually, this is 2.
* @return never {@code null}. Fitting result with indications of whether the fit converged or was even attempted.
*/
private MixtureMultivariateNormalFitResult retrieveGaussianMixtureModelForFilteredTargets(final List<ModeledSegment> segments, final TargetCollection<ReadCountRecord.SingleSampleRecord> targets, double minProportion, int numComponents) {
// For each target in a segment that contains enough targets, normalize the difference against the segment mean
// and collapse the filtered targets into the copy ratio estimates.
final List<Double> filteredTargetsSegDiff = getNumProbeFilteredTargetList(segments, targets, minProportion);
if (filteredTargetsSegDiff.size() < numComponents) {
return new MixtureMultivariateNormalFitResult(null, false, false);
}
// Assume that Apache Commons wants data points in the first dimension.
// Note that the second dimension has length 2 (instead of 1) to work around the Apache Commons API.
final double[][] filteredTargetsSegDiff2d = new double[filteredTargetsSegDiff.size()][2];
// Convert the filtered targets into a 2d array (even though the data are inherently 1D). The second
// dimension is uncorrelated Gaussian noise, used only to get around the Apache Commons API, which
// throws an exception if the length of the second dimension is less than 2.
final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED));
final NormalDistribution nd = new NormalDistribution(rng, 0, .1);
for (int i = 0; i < filteredTargetsSegDiff.size(); i++) {
filteredTargetsSegDiff2d[i][0] = filteredTargetsSegDiff.get(i);
filteredTargetsSegDiff2d[i][1] = nd.sample();
}
final MixtureMultivariateNormalDistribution estimateEM0 = MultivariateNormalMixtureExpectationMaximization.estimate(filteredTargetsSegDiff2d, numComponents);
final MultivariateNormalMixtureExpectationMaximization multivariateNormalMixtureExpectationMaximization = new MultivariateNormalMixtureExpectationMaximization(filteredTargetsSegDiff2d);
try {
multivariateNormalMixtureExpectationMaximization.fit(estimateEM0);
} catch (final MaxCountExceededException | ConvergenceException | SingularMatrixException e) {
// did not converge. Include the model as it was when the exception was thrown.
return new MixtureMultivariateNormalFitResult(multivariateNormalMixtureExpectationMaximization.getFittedModel(), false, true);
}
return new MixtureMultivariateNormalFitResult(multivariateNormalMixtureExpectationMaximization.getFittedModel(), true, true);
}
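The padding trick above can be isolated from the surrounding GATK plumbing. A minimal sketch of the same workaround against the plain Commons Math 3 API (made-up values, seed and class name) pads inherently 1D observations with a low-variance noise column so that the fitter, which rejects data with fewer than two columns, will accept them:

import java.util.Random;

import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.RandomGeneratorFactory;

public class OneDimensionalGmmSketch {
  public static void main(String[] args) {
    // Inherently 1D observations (e.g. differences from a segment mean).
    final double[] values = {-0.05, 0.02, -0.01, 0.03, 0.98, 1.02, 1.05, 0.95, 1.01, 0.04};
    // Pad to 2D: the second column is small, uncorrelated Gaussian noise, present only
    // because the Commons Math fitter rejects data with fewer than two columns.
    final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(42));
    final NormalDistribution noise = new NormalDistribution(rng, 0, 0.1);
    final double[][] padded = new double[values.length][2];
    for (int i = 0; i < values.length; i++) {
      padded[i][0] = values[i];
      padded[i][1] = noise.sample();
    }
    // Estimate a 2-component starting model and run EM.
    final MixtureMultivariateNormalDistribution initial =
        MultivariateNormalMixtureExpectationMaximization.estimate(padded, 2);
    final MultivariateNormalMixtureExpectationMaximization fitter =
        new MultivariateNormalMixtureExpectationMaximization(padded);
    fitter.fit(initial);
    System.out.println("log-likelihood = " + fitter.getLogLikelihood());
  }
}

After fitting, only the first dimension of each fitted component carries the quantity of interest; the noise column exists purely to satisfy the API.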
use of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization in project gatk by broadinstitute.
the class CoverageDropoutDetector method retrieveGaussianMixtureModelForFilteredTargets.
/** <p>Produces a Gaussian mixture model based on the difference between targets and segment means.</p>
* <p>Filters targets to those in segments that contain more than minProportion of all targets.</p>
* <p>Returns a result with a {@code null} model if no targets pass filtering. Please note that in
* these cases, the rest of this class uses this to assume that a GMM is not a good model.</p>
*
* @param segments -- segments with segment mean in log2 copy ratio space
* @param targets -- targets with a log2 copy ratio estimate
* @param minProportion -- minimum proportion of all targets that a given segment must have in order to be used
* in the evaluation
* @param numComponents -- number of components to use in the GMM. Usually, this is 2.
* @return never {@code null}. Fitting result with indications of whether the fit converged or was even attempted.
*/
private MixtureMultivariateNormalFitResult retrieveGaussianMixtureModelForFilteredTargets(final List<ModeledSegment> segments, final TargetCollection<ReadCountRecord.SingleSampleRecord> targets, double minProportion, int numComponents) {
// For each target in a segment that contains enough targets, normalize the difference against the segment mean
// and collapse the filtered targets into the copy ratio estimates.
final List<Double> filteredTargetsSegDiff = getNumProbeFilteredTargetList(segments, targets, minProportion);
if (filteredTargetsSegDiff.size() < numComponents) {
return new MixtureMultivariateNormalFitResult(null, false, false);
}
// Assume that Apache Commons wants data points in the first dimension.
// Note that the second dimension has length 2 (instead of 1) to work around the Apache Commons API.
final double[][] filteredTargetsSegDiff2d = new double[filteredTargetsSegDiff.size()][2];
// Convert the filtered targets into a 2d array (even though the data are inherently 1D). The second
// dimension is uncorrelated Gaussian noise, used only to get around the Apache Commons API, which
// throws an exception if the length of the second dimension is less than 2.
final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED));
final NormalDistribution nd = new NormalDistribution(rng, 0, .1);
for (int i = 0; i < filteredTargetsSegDiff.size(); i++) {
filteredTargetsSegDiff2d[i][0] = filteredTargetsSegDiff.get(i);
filteredTargetsSegDiff2d[i][1] = nd.sample();
}
final MixtureMultivariateNormalDistribution estimateEM0 = MultivariateNormalMixtureExpectationMaximization.estimate(filteredTargetsSegDiff2d, numComponents);
final MultivariateNormalMixtureExpectationMaximization multivariateNormalMixtureExpectationMaximization = new MultivariateNormalMixtureExpectationMaximization(filteredTargetsSegDiff2d);
try {
multivariateNormalMixtureExpectationMaximization.fit(estimateEM0);
} catch (final MaxCountExceededException | ConvergenceException | SingularMatrixException e) {
// did not converge. Include the model as it was when the exception was thrown.
return new MixtureMultivariateNormalFitResult(multivariateNormalMixtureExpectationMaximization.getFittedModel(), false, true);
}
return new MixtureMultivariateNormalFitResult(multivariateNormalMixtureExpectationMaximization.getFittedModel(), true, true);
}
use of org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method canFit.
@SeededTest
void canFit(RandomSeed seed) {
// Test versus the Commons Math estimation
final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
final int sampleSize = 1000;
// Number of components
for (int n = 2; n <= 3; n++) {
final double[] sampleWeights = createWeights(n, rng);
final double[][] sampleMeans = create(n, 2, rng, -5, 5);
final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
final double[][] data = createData2d(sampleSize, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
final MixtureMultivariateGaussianDistribution initialModel1 = MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
final MultivariateGaussianMixtureExpectationMaximization fitter1 = new MultivariateGaussianMixtureExpectationMaximization(data);
Assertions.assertTrue(fitter1.fit(initialModel1));
final MultivariateNormalMixtureExpectationMaximization fitter2 = new MultivariateNormalMixtureExpectationMaximization(data);
fitter2.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, n));
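// Commons Math reports the mean log-likelihood per sample; the GDSC fitter reports the total, hence the division by sampleSize.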
final double ll1 = fitter1.getLogLikelihood() / sampleSize;
Assertions.assertNotEquals(0, ll1);
final double ll2 = fitter2.getLogLikelihood();
TestAssertions.assertTest(ll2, ll1, test);
final MixtureMultivariateGaussianDistribution model1 = fitter1.getFittedModel();
Assertions.assertNotNull(model1);
final MixtureMultivariateNormalDistribution model2 = fitter2.getFittedModel();
// Check fitted models are the same
final List<Pair<Double, MultivariateNormalDistribution>> comp = model2.getComponents();
final double[] weights = model1.getWeights();
final MultivariateGaussianDistribution[] distributions = model1.getDistributions();
Assertions.assertEquals(n, comp.size());
Assertions.assertEquals(n, weights.length);
Assertions.assertEquals(n, distributions.length);
for (int i = 0; i < n; i++) {
TestAssertions.assertTest(comp.get(i).getFirst(), weights[i], test, "weight");
final MultivariateNormalDistribution d = comp.get(i).getSecond();
TestAssertions.assertArrayTest(d.getMeans(), distributions[i].getMeans(), test, "means");
TestAssertions.assertArrayTest(d.getCovariances().getData(), distributions[i].getCovariances(), test, "covariances");
}
final int iterations = fitter1.getIterations();
Assertions.assertNotEquals(0, iterations);
// Test that the fit reports a failure to converge when the iteration limit is too low
if (iterations > 2) {
Assertions.assertFalse(fitter1.fit(initialModel1, 2, DEFAULT_CONVERGENCE_CHECKER));
}
}
}