Search in sources :

Example 1 with MultivariateNormalDistribution

use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method createData2d.

/**
 * Draws samples from a mixture of 2D Gaussian distributions using the Commons Math mixture
 * sampler. One component is built per entry of {@code weights}; every other array is indexed
 * by the same component index.
 *
 * @param sampleSize the number of samples to draw
 * @param rng the random generator (adapted for use by Commons Math)
 * @param weights the weight of each component
 * @param means the means of each component, passed unchanged to the component distribution
 * @param stdDevs the standard deviations of each component as {x, y}
 * @param correlations the correlation between x and y for each component
 * @return the samples returned by the mixture distribution
 */
private static double[][] createData2d(int sampleSize, UniformRandomProvider rng, double[] weights, double[][] means, double[][] stdDevs, double[] correlations) {
    final ArrayList<Pair<Double, MultivariateNormalDistribution>> mixture = new ArrayList<>();
    for (int k = 0; k < weights.length; k++) {
        final double sdX = stdDevs[k][0];
        final double sdY = stdDevs[k][1];
        // Covariance matrix entries: variances on the diagonal, shared covariance off it
        final double varX = sdX * sdX;
        final double varY = sdY * sdY;
        final double covXy = correlations[k] * sdX * sdY;
        final double[][] covariance = new double[][] { { varX, covXy }, { covXy, varY } };
        final MultivariateNormalDistribution component = new MultivariateNormalDistribution(means[k], covariance);
        mixture.add(new Pair<>(weights[k], component));
    }
    // Sampling is delegated to Commons Math, driven by the supplied generator
    return new MixtureMultivariateNormalDistribution(new RandomGeneratorAdapter(rng), mixture).sample(sampleSize);
}
Also used : MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) ArrayList(java.util.ArrayList) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) Pair(org.apache.commons.math3.util.Pair)

Example 2 with MultivariateNormalDistribution

use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method canCreateMultivariateGaussianDistribution.

/**
 * Checks that a created distribution returns the same mean and covariance array instances it
 * was given, reports standard deviations equal to the square root of the covariance diagonal,
 * and produces the same density as the Commons Math distribution at each data point.
 */
@Test
void canCreateMultivariateGaussianDistribution() {
    final double[][] data = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
    final double[] means = getColumnMeans(data);
    final double[][] covariances = getCovariance(data);
    final MultivariateGaussianDistribution actual = MultivariateGaussianDistribution.create(means, covariances);

    // The arrays are expected back by reference
    Assertions.assertSame(means, actual.getMeans());
    Assertions.assertSame(covariances, actual.getCovariances());

    // Standard deviations come from the covariance diagonal
    final double[] stdDevs = actual.getStandardDeviations();
    Assertions.assertEquals(covariances.length, stdDevs.length);
    int diag = 0;
    while (diag < stdDevs.length) {
        Assertions.assertEquals(Math.sqrt(covariances[diag][diag]), stdDevs[diag]);
        diag++;
    }

    // Densities must agree with Apache Commons Math
    final MultivariateNormalDistribution reference = new MultivariateNormalDistribution(means, covariances);
    for (int row = 0; row < data.length; row++) {
        final double[] point = data[row];
        Assertions.assertEquals(reference.density(point), actual.density(point));
    }
}
Also used : MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Example 3 with MultivariateNormalDistribution

use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in project lvm4j by dirmeier.

the class GaussianMixtureModel method logLik.

/**
 * Computes the log-likelihood of the data held in {@code _X} under a mixture of {@code K}
 * multivariate normal components: the sum over the first {@code _N} data points of the log of
 * the weighted sum of component densities.
 *
 * @param K the number of mixture components
 * @param comps the component means, variances and mixing weights
 * @return the log-likelihood
 */
private double logLik(final int K, GaussianMixtureComponents comps) {
    final MultivariateNormalDistribution[] components = new MultivariateNormalDistribution[K];
    for (int k = 0; k < K; k++) {
        components[k] = new MultivariateNormalDistribution(comps.means(k), comps.var(k));
    }
    double total = 0;
    for (int row = 0; row < _N; row++) {
        total += Math.log(mixtureDensityAt(_X[row], K, comps, components));
    }
    return total;
}

/** Returns the weighted sum of the component densities evaluated at a single point. */
private static double mixtureDensityAt(double[] point, int K, GaussianMixtureComponents comps, MultivariateNormalDistribution[] components) {
    double density = 0;
    for (int k = 0; k < K; k++) {
        density += comps.mixingWeight(k) * components[k].density(point);
    }
    return density;
}
Also used : MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution)

Example 4 with MultivariateNormalDistribution

use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in project pyramid by cheng-li.

the class MultiLabelSynthesizer method gaussianNoise.

/**
 * Generates a synthetic multi-label data set with 3 labels and 3 features whose label scores
 * are perturbed by zero-mean multivariate Gaussian noise.
 *
 * <p>Each label uses a weight vector with a single non-zero entry:
 * <pre>
 * y0: w=(0,1,0)
 * y1: w=(1,0,0)
 * y2: w=(0,0,1)
 * </pre>
 * Feature values are drawn with {@code Sampling.doubleUniform(-1, 1)}. Label k is added to a
 * data point when {@code w_k . x + noise_k >= 0}. The number of (data point, label) pairs for
 * which the noisy score has the opposite sign to the noise-free score is printed to standard
 * output.
 *
 * @param numData the number of data points to generate
 * @return the generated data set
 */
public static MultiLabelClfDataSet gaussianNoise(int numData) {
    final int numClass = 3;
    final int numFeature = 3;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();

    // weights: label k depends only on the feature at activeFeature[k]
    final int[] activeFeature = { 1, 0, 2 };
    Vector[] weights = new Vector[numClass];
    for (int k = 0; k < numClass; k++) {
        weights[k] = new DenseVector(numFeature);
        weights[k].set(activeFeature[k], 1);
    }

    // features: filled row by row, column by column
    for (int row = 0; row < numData; row++) {
        for (int col = 0; col < numFeature; col++) {
            dataSet.setFeatureValue(row, col, Sampling.doubleUniform(-1, 1));
        }
    }

    // noise: zero mean with a symmetric covariance matrix
    double[] means = new double[numClass];
    double[][] covars = {
        { 0.5, 0.02, -0.03 },
        { 0.02, 0.2, -0.03 },
        { -0.03, -0.03, 0.3 }
    };
    MultivariateNormalDistribution noiseDistribution = new MultivariateNormalDistribution(means, covars);

    // labels: one noise draw per data point, one noise component per label
    int flipped = 0;
    for (int row = 0; row < numData; row++) {
        double[] noise = noiseDistribution.sample();
        for (int k = 0; k < numClass; k++) {
            double cleanScore = weights[k].dot(dataSet.getRow(row));
            double noisyScore = cleanScore + noise[k];
            if (noisyScore >= 0) {
                dataSet.addLabel(row, k);
            }
            // opposite signs mean the noise changed the outcome for this label
            if (cleanScore * noisyScore < 0) {
                flipped++;
            }
        }
    }
    System.out.println("number of flipped bits = " + flipped);
    return dataSet;
}
Also used : DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) DenseVector(org.apache.mahout.math.DenseVector) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution)

Example 5 with MultivariateNormalDistribution

use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method canEstimateInitialMixture.

/**
 * Compares the initial mixture estimate against the Commons Math estimate for 2 and 3
 * components fitted to randomly generated 2D mixture data. Weights must be equal; means and
 * covariances are compared using a closeness predicate.
 *
 * @param seed the seed for the random generator
 */
@SeededTest
void canEstimateInitialMixture(RandomSeed seed) {
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate closeEnough = TestHelper.doublesAreClose(1e-5, 1e-16);
    for (int numComponents = 2; numComponents <= 3; numComponents++) {
        // Random mixture parameters used to generate the sample data
        final double[] sampleWeights = createWeights(numComponents, rng);
        final double[][] sampleMeans = create(numComponents, 2, rng, -5, 5);
        final double[][] sampleStdDevs = create(numComponents, 2, rng, 1, 10);
        final double[] sampleCorrelations = create(numComponents, rng, -0.9, 0.9);
        final double[][] data = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);

        // Estimate with the implementation under test, then with the Commons Math reference
        final MixtureMultivariateGaussianDistribution actual = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents);
        final MixtureMultivariateNormalDistribution reference = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);

        final List<Pair<Double, MultivariateNormalDistribution>> expected = reference.getComponents();
        final double[] weights = actual.getWeights();
        final MultivariateGaussianDistribution[] distributions = actual.getDistributions();
        Assertions.assertEquals(numComponents, expected.size());
        Assertions.assertEquals(numComponents, weights.length);
        Assertions.assertEquals(numComponents, distributions.length);

        for (int c = 0; c < numComponents; c++) {
            final Pair<Double, MultivariateNormalDistribution> expectedComponent = expected.get(c);
            // The estimated weight is compared for equality, without a tolerance
            Assertions.assertEquals(expectedComponent.getFirst(), weights[c], "weight");
            final MultivariateNormalDistribution expectedDist = expectedComponent.getSecond();
            final MultivariateGaussianDistribution actualDist = distributions[c];
            TestAssertions.assertArrayTest(expectedDist.getMeans(), actualDist.getMeans(), closeEnough, "means");
            TestAssertions.assertArrayTest(expectedDist.getCovariances().getData(), actualDist.getCovariances(), closeEnough, "covariances");
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 6
MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution): 4
Pair (org.apache.commons.math3.util.Pair): 3
MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution): 3
MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution): 3
SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest): 3
UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 2
DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate): 2
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 1
ArrayList (java.util.ArrayList): 1
MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization): 1
DenseVector (org.apache.mahout.math.DenseVector): 1
Vector (org.apache.mahout.math.Vector): 1
Test (org.junit.jupiter.api.Test): 1
RandomGeneratorAdapter (uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter): 1