Search in sources :

Example 1 with NormalDistribution

use of org.apache.commons.math3.distribution.NormalDistribution in project deeplearning4j by deeplearning4j.

the class TestReconstructionDistributions method testGaussianLogProb.

/**
 * Compares the negative log probability reported by {@link GaussianReconstructionDistribution}
 * with a reference value computed element by element using Apache Commons Math's
 * {@link NormalDistribution}. The comparison is repeated for several minibatch sizes and for both
 * the averaged and the summed form of the result.
 */
@Test
public void testGaussianLogProb() {
    Nd4j.getRandom().setSeed(12345);
    final int inputSize = 4;
    final int[] minibatchSizes = { 1, 2, 5 };
    for (boolean average : new boolean[] { true, false }) {
        for (int minibatch : minibatchSizes) {
            final INDArray x = Nd4j.rand(minibatch, inputSize);
            final INDArray mean = Nd4j.randn(minibatch, inputSize);
            final INDArray logStdevSquared = Nd4j.rand(minibatch, inputSize).subi(0.5);

            // Parameter layout: the first inputSize columns hold the means, the remaining
            // inputSize columns hold log(sigma^2).
            final INDArray distributionParams = Nd4j.createUninitialized(new int[] { minibatch, 2 * inputSize });
            final INDArray meanColumns = distributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(0, inputSize));
            meanColumns.assign(mean);
            final INDArray logVarColumns = distributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(inputSize, 2 * inputSize));
            logVarColumns.assign(logStdevSquared);

            final ReconstructionDistribution dist = new GaussianReconstructionDistribution("identity");
            final double negLogProb = dist.negLogProbability(x, distributionParams, average);
            final INDArray exampleNegLogProb = dist.exampleNegLogProbability(x, distributionParams);
            assertArrayEquals(new int[] { minibatch, 1 }, exampleNegLogProb.shape());

            // Reference calculation with Apache Commons Math. The total is accumulated one element
            // at a time (not one row at a time) so the floating-point summation order is fixed.
            double totalLogProb = 0.0;
            for (int row = 0; row < minibatch; row++) {
                double rowLogProb = 0.0;
                for (int col = 0; col < inputSize; col++) {
                    final double sigma = Math.sqrt(Math.exp(logStdevSquared.getDouble(row, col)));
                    final NormalDistribution reference = new NormalDistribution(mean.getDouble(row, col), sigma);
                    final double elementLogProb = reference.logDensity(x.getDouble(row, col));
                    totalLogProb += elementLogProb;
                    rowLogProb += elementLogProb;
                }
                assertEquals(-exampleNegLogProb.getDouble(row), rowLogProb, 1e-6);
            }

            final double expNegLogProb = average ? -totalLogProb / minibatch : -totalLogProb;
            assertEquals(expNegLogProb, negLogProb, 1e-6);

            // Also exercise the sampling methods; their results are not inspected here, so this
            // only verifies that the calls complete without throwing.
            final int count = minibatch * inputSize;
            final INDArray arr = Nd4j.linspace(-3, 3, count).reshape(minibatch, inputSize);
            dist.generateAtMean(arr);
            dist.generateRandom(arr);
        }
    }
}
Also used : GaussianReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution) INDArray(org.nd4j.linalg.api.ndarray.INDArray) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) GaussianReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution) ReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution) ExponentialReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.ExponentialReconstructionDistribution) BernoulliReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution) Test(org.junit.Test)

Example 2 with NormalDistribution

use of org.apache.commons.math3.distribution.NormalDistribution in project EnrichmentMapApp by BaderLab.

the class MannWhitneyUTestSided method calculateAsymptoticPValue.

/**
 * Computes the asymptotic p-value of a Mann-Whitney U test using the normal approximation
 * of the U statistic.
 *
 * @param Umin smallest Mann-Whitney U value; this is the statistic that is standardized
 * @param U1 Mann-Whitney U value of the first sample
 * @param U2 Mann-Whitney U value of the second sample
 * @param n1 number of subjects in first sample
 * @param n2 number of subjects in second sample
 * @param side which alternative to evaluate: {@code Type.TWO_SIDED} returns the two-sided
 * p-value, {@code Type.LESS} treats {@code U1 < U2} as the supporting direction, and any other
 * value (presumably a "greater" alternative) treats {@code U1 > U2} as the supporting direction
 * @return asymptotic p-value for the requested side
 * @throws ConvergenceException if the p-value can not be computed
 * due to a convergence error
 * @throws MaxCountExceededException if the maximum number of
 * iterations is exceeded
 */
private double calculateAsymptoticPValue(final double Umin, final double U1, final double U2, final int n1, final int n2, final Type side) throws ConvergenceException, MaxCountExceededException {
    /* long multiplication to avoid overflow (double not used due to efficiency
     * and to avoid precision loss)
     */
    final long n1n2prod = (long) n1 * n2;
    // http://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U#Normal_approximation
    final double EU = n1n2prod / 2.0;
    // The (n1 + n2 + 1) factor is promoted to double before multiplying: in integer arithmetic the
    // int sum can wrap, and the long product n1n2prod * (n1 + n2 + 1) exceeds Long.MAX_VALUE for
    // samples of roughly two million subjects each, silently yielding a wrong (even negative) variance.
    final double VarU = n1n2prod * ((double) n1 + n2 + 1) / 12.0;
    final double z = (Umin - EU) / FastMath.sqrt(VarU);
    // No try-catch or advertised exception because args are valid
    final NormalDistribution standardNormal = new NormalDistribution(0, 1);
    final double twoSidedP = 2 * standardNormal.cumulativeProbability(z);
    if (side == Type.TWO_SIDED) {
        return twoSidedP;
    }
    // One-sided: half of the two-sided value when the observed ordering of U1 and U2 matches the
    // requested direction, its complement otherwise (ties fall into the complement branch).
    final double halfP = 0.5 * twoSidedP;
    final boolean orderingMatchesSide = (side == Type.LESS) ? (U1 < U2) : (U1 > U2);
    return orderingMatchesSide ? halfP : 1.0 - halfP;
}
Also used : NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution)

Example 3 with NormalDistribution

use of org.apache.commons.math3.distribution.NormalDistribution in project gatk by broadinstitute.

the class SliceSamplerUnitTest method testInitialPointOutOfRange.

/**
 * Samples from a sampler restricted to [0, 1] starting at an initial point of -10, which lies
 * outside that range, and expects an {@link IllegalArgumentException}.
 */
@Test(expectedExceptions = IllegalArgumentException.class)
public void testInitialPointOutOfRange() {
    rng.setSeed(RANDOM_SEED);

    // Target density: normal with mean 5 and standard deviation 0.75.
    final NormalDistribution target = new NormalDistribution(5., 0.75);
    final Function<Double, Double> logDensity = target::logDensity;

    // Sampling range and step width.
    final double lowerBound = 0.;
    final double upperBound = 1.;
    final double stepWidth = 0.5;
    final SliceSampler sampler = new SliceSampler(rng, logDensity, lowerBound, upperBound, stepWidth);

    // The starting point is deliberately below the lower bound.
    final double outOfRangeStart = -10.;
    sampler.sample(outOfRangeStart);
}
Also used : NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) Test(org.testng.annotations.Test)

Example 4 with NormalDistribution

use of org.apache.commons.math3.distribution.NormalDistribution in project gatk by broadinstitute.

the class SliceSamplerUnitTest method testSliceSamplingOfNormalDistribution.

/**
 * Slice-samples a normal distribution with mean 5 and standard deviation 0.75 over an unbounded
 * range. Verifies that 10000 samples reproduce the mean to within a relative error of 0.5% and
 * the standard deviation to within a relative error of 2%.
 */
@Test
public void testSliceSamplingOfNormalDistribution() {
    rng.setSeed(RANDOM_SEED);

    final double expectedMean = 5.;
    final double expectedStandardDeviation = 0.75;
    final NormalDistribution target = new NormalDistribution(expectedMean, expectedStandardDeviation);
    final Function<Double, Double> logDensity = target::logDensity;

    // Unbounded sampling range, step width 0.5.
    final SliceSampler sampler = new SliceSampler(rng, logDensity, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.5);

    final double startingPoint = 1.;
    final int sampleCount = 10000;
    final double[] draws = Doubles.toArray(sampler.sample(startingPoint, sampleCount));

    final double observedMean = new Mean().evaluate(draws);
    final double observedStandardDeviation = new StandardDeviation().evaluate(draws);
    Assert.assertEquals(relativeError(observedMean, expectedMean), 0., 0.005);
    Assert.assertEquals(relativeError(observedStandardDeviation, expectedStandardDeviation), 0., 0.02);
}
Also used : Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) StandardDeviation(org.apache.commons.math3.stat.descriptive.moment.StandardDeviation) Test(org.testng.annotations.Test)

Example 5 with NormalDistribution

use of org.apache.commons.math3.distribution.NormalDistribution in project gatk-protected by broadinstitute.

the class CoverageDropoutDetectorTest method getUnivariateGaussianTargetsWithDropout.

/**
 * Builds a synthetic data set of 10000 single-sample coverage records on chr1 together with two
 * modeled segments.
 *
 * <p>Coverage is drawn from a normal distribution with mean 1 and standard deviation
 * {@code sigma}; the last 2000 records are additionally shifted by +0.5. Each record is then,
 * with probability {@code dropoutRate}, scaled by a uniform multiplier in [0.25, 0.75).
 * The random source is seeded with a fixed value, so the output is reproducible.
 *
 * @param sigma standard deviation of the simulated coverage
 * @param dropoutRate probability that a given record has its coverage scaled down
 * @return a single row containing the target collection and the list of segments
 */
private Object[][] getUnivariateGaussianTargetsWithDropout(final double sigma, final double dropoutRate) {
    final Random seededRandom = new Random(337);
    final RandomGenerator randomGenerator = RandomGeneratorFactory.createRandomGenerator(seededRandom);
    final NormalDistribution coverageDistribution = new NormalDistribution(randomGenerator, 1, sigma);
    final UniformRealDistribution unitUniform = new UniformRealDistribution(randomGenerator, 0, 1.0);

    final int numDataPoints = 10000;
    final int numEventPoints = 2000;
    final int firstEventIndex = numDataPoints - numEventPoints;

    final List<ReadCountRecord.SingleSampleRecord> records = new ArrayList<>();
    for (int i = 0; i < numDataPoints; i++) {
        final double eventShift;
        if (i < firstEventIndex) {
            eventShift = 0.0;
        } else {
            eventShift = 0.5;
        }
        // Draw order matters for reproducibility: coverage first, then the dropout decision,
        // then (only on dropout) the multiplier.
        double coverage = coverageDistribution.sample() + eventShift;
        if (unitUniform.sample() < dropoutRate) {
            // Randomly selected target: reduce by 25%-75% (uniformly distributed)
            final double multiplier = 0.25 + unitUniform.sample() / 2;
            coverage *= multiplier;
        }
        final int start = 100 + 2 * i;
        final SimpleInterval interval = new SimpleInterval("chr1", start, start + 1);
        final Target target = new Target("arbitrary_name", interval);
        records.add(new ReadCountRecord.SingleSampleRecord(target, coverage));
    }
    final HashedListTargetCollection<ReadCountRecord.SingleSampleRecord> targets = new HashedListTargetCollection<>(records);

    final List<ModeledSegment> segments = new ArrayList<>();
    segments.add(new ModeledSegment(new SimpleInterval("chr1", 100, 16050), 8000, 1));
    segments.add(new ModeledSegment(new SimpleInterval("chr1", 16100, 20200), 2000, 1.5));

    return new Object[][] { { targets, segments } };
}
Also used : UniformRealDistribution(org.apache.commons.math3.distribution.UniformRealDistribution) ArrayList(java.util.ArrayList) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Random(java.util.Random) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval)

Aggregations

NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution)23 Random (java.util.Random)6 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)6 ArrayList (java.util.ArrayList)5 Test (org.testng.annotations.Test)5 AbstractRealDistribution (org.apache.commons.math3.distribution.AbstractRealDistribution)4 SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)4 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)3 UniformRealDistribution (org.apache.commons.math3.distribution.UniformRealDistribution)3 VisibleForTesting (com.google.common.annotations.VisibleForTesting)2 MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution)2 MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization)2 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)2 MaxCountExceededException (org.apache.commons.math3.exception.MaxCountExceededException)2 SingularMatrixException (org.apache.commons.math3.linear.SingularMatrixException)2 Mean (org.apache.commons.math3.stat.descriptive.moment.Mean)2 StandardDeviation (org.apache.commons.math3.stat.descriptive.moment.StandardDeviation)2 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)1 ValueType (io.druid.segment.column.ValueType)1 AbstractIntegerDistribution (org.apache.commons.math3.distribution.AbstractIntegerDistribution)1