Search in sources :

Example 41 with Mean

Use of org.apache.commons.math3.stat.descriptive.moment.Mean in project incubator-systemml by apache.

From class ParameterizedBuiltin, method computeFromDistribution.

/**
 * Evaluates a distribution-specific cdf (lower or upper tail) or, when {@code inverse} is set,
 * the inverse cdf.
 *
 * <p>Values read from {@code params}: {@code target} (always); {@code lower.tail} (optional,
 * defaults to true, not consulted for the inverse cdf); and per distribution {@code mean} /
 * {@code sd} for NORMAL (defaults 0 and 1), {@code rate} for EXP (default 1), {@code df} for
 * CHISQ and T, and {@code df1} / {@code df2} for F.
 *
 * @param dcode probability distribution code
 * @param params map of parameters
 * @param inverse true to compute the inverse cdf instead of the cdf
 * @return cdf (lower or upper tail) or inverse cdf evaluated at {@code target}
 * @throws DMLRuntimeException if a required degrees-of-freedom parameter is missing, a
 *     distribution parameter is not positive, or {@code dcode} is not supported
 */
private static double computeFromDistribution(ProbabilityDistributionCode dcode, HashMap<String, String> params, boolean inverse) {
    // "target" is a quantile for the cdf and a probability for the inverse cdf
    final double target = Double.parseDouble(params.get("target"));
    final String tailSpec = params.get("lower.tail");
    final boolean lowerTail = (tailSpec == null) || Boolean.parseBoolean(tailSpec);

    final AbstractRealDistribution dist;
    switch (dcode) {
        case NORMAL: {
            final String meanSpec = params.get("mean");
            final String sdSpec = params.get("sd");
            final double mu = (meanSpec != null) ? Double.parseDouble(meanSpec) : 0.0;
            final double sigma = (sdSpec != null) ? Double.parseDouble(sdSpec) : 1.0;
            if (sigma <= 0) {
                throw new DMLRuntimeException("Standard deviation for Normal distribution must be positive (" + sigma + ")");
            }
            dist = new NormalDistribution(mu, sigma);
            break;
        }
        case EXP: {
            final String rateSpec = params.get("rate");
            final double rate = (rateSpec != null) ? Double.parseDouble(rateSpec) : 1.0;
            if (rate <= 0) {
                throw new DMLRuntimeException("Rate for Exponential distribution must be positive (" + rate + ")");
            }
            // commons-math parameterizes the exponential distribution by its mean (= 1/rate)
            dist = new ExponentialDistribution(1.0 / rate);
            break;
        }
        case CHISQ: {
            final String dfSpec = params.get("df");
            if (dfSpec == null) {
                throw new DMLRuntimeException("Degrees of freedom must be specified for chi-squared distribution " + "(e.g., q=qchisq(0.5, df=20); p=pchisq(target=q, df=1.2))");
            }
            final int degrees = UtilFunctions.parseToInt(dfSpec);
            if (degrees <= 0) {
                throw new DMLRuntimeException("Degrees of Freedom for chi-squared distribution must be positive (" + degrees + ")");
            }
            dist = new ChiSquaredDistribution(degrees);
            break;
        }
        case F: {
            final String df1Spec = params.get("df1");
            final String df2Spec = params.get("df2");
            if (df1Spec == null || df2Spec == null) {
                throw new DMLRuntimeException("Degrees of freedom must be specified for F distribution " + "(e.g., q = qf(target=0.5, df1=20, df2=30); p=pf(target=q, df1=20, df2=30))");
            }
            final int numeratorDf = UtilFunctions.parseToInt(df1Spec);
            final int denominatorDf = UtilFunctions.parseToInt(df2Spec);
            if (numeratorDf <= 0 || denominatorDf <= 0) {
                throw new DMLRuntimeException("Degrees of Freedom for F distribution must be positive (" + numeratorDf + "," + denominatorDf + ")");
            }
            dist = new FDistribution(numeratorDf, denominatorDf);
            break;
        }
        case T: {
            final String dfSpec = params.get("df");
            if (dfSpec == null) {
                throw new DMLRuntimeException("Degrees of freedom is needed to compute probabilities from t distribution " + "(e.g., q = qt(target=0.5, df=10); p = pt(target=q, df=10))");
            }
            final int degrees = UtilFunctions.parseToInt(dfSpec);
            if (degrees <= 0) {
                throw new DMLRuntimeException("Degrees of Freedom for t distribution must be positive (" + degrees + ")");
            }
            dist = new TDistribution(degrees);
            break;
        }
        default:
            throw new DMLRuntimeException("Invalid distribution code: " + dcode);
    }

    if (inverse) {
        return dist.inverseCumulativeProbability(target);
    }
    final double lowerCdf = dist.cumulativeProbability(target);
    // TODO: more accurate distribution-specific computation of upper tail probabilities
    return lowerTail ? lowerCdf : 1.0 - lowerCdf;
}
Also used : AbstractRealDistribution(org.apache.commons.math3.distribution.AbstractRealDistribution) ChiSquaredDistribution(org.apache.commons.math3.distribution.ChiSquaredDistribution) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) ExponentialDistribution(org.apache.commons.math3.distribution.ExponentialDistribution) TDistribution(org.apache.commons.math3.distribution.TDistribution) FDistribution(org.apache.commons.math3.distribution.FDistribution) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 42 with Mean

Use of org.apache.commons.math3.stat.descriptive.moment.Mean in project incubator-systemml by apache.

From class PoissonPRNGenerator, method setup.

/**
 * (Re)initializes the underlying Poisson distribution with the given mean and seed.
 *
 * <p>The requested mean is recorded in {@code _mean} and is the value handed to
 * {@code PoissonDistribution}, so the field and the distribution in use agree.
 *
 * @param mean mean of the Poisson distribution
 * @param sd seed for the random generator
 */
public void setup(double mean, long sd) {
    // Honor the caller's mean; building from the field alone silently ignored this argument.
    _mean = mean;
    seed = sd;
    SynchronizedRandomGenerator srg = new SynchronizedRandomGenerator(new Well1024a());
    srg.setSeed(seed);
    _pdist = new PoissonDistribution(srg, _mean, PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) SynchronizedRandomGenerator(org.apache.commons.math3.random.SynchronizedRandomGenerator) Well1024a(org.apache.commons.math3.random.Well1024a)

Example 43 with Mean

Use of org.apache.commons.math3.stat.descriptive.moment.Mean in project deeplearning4j by deeplearning4j.

From class TestReconstructionDistributions, method testExponentialLogProb.

/**
 * Compares {@code ExponentialReconstructionDistribution} (tanh activation) against a reference
 * built from Apache Commons Math {@code ExponentialDistribution}. The per-example and aggregate
 * negative log-likelihoods must agree, and samples generated at the mean or at random must be
 * non-negative.
 */
@Test
public void testExponentialLogProb() {
    Nd4j.getRandom().setSeed(12345);
    final int inputSize = 4;
    final int[] minibatchSizes = { 1, 2, 5 };
    final boolean[] averageFlags = { true, false };
    final Random rng = new Random(12345);
    for (boolean average : averageFlags) {
        for (int mb : minibatchSizes) {
            // Observations: every entry is 0 or 1
            INDArray labels = Nd4j.zeros(mb, inputSize);
            for (int row = 0; row < mb; row++) {
                for (int col = 0; col < inputSize; col++) {
                    labels.putScalar(row, col, rng.nextInt(2));
                }
            }
            // Pre-activation values (rand * 2 - 1); gamma is their tanh
            INDArray preOut = Nd4j.rand(mb, inputSize).muli(2).subi(1);
            INDArray gammas = Transforms.tanh(preOut, true);
            ReconstructionDistribution dist = new ExponentialReconstructionDistribution("tanh");
            double negLogProb = dist.negLogProbability(labels, preOut, average);
            INDArray perExample = dist.exampleNegLogProbability(labels, preOut);
            assertArrayEquals(new int[] { mb, 1 }, perExample.shape());

            // Reference log-likelihood computed with Apache Commons Math
            double totalLogProb = 0.0;
            for (int row = 0; row < mb; row++) {
                double rowLogProb = 0.0;
                for (int col = 0; col < inputSize; col++) {
                    double lambda = Math.exp(gammas.getDouble(row, col));
                    // Commons Math parameterizes the exponential by its mean, 1/lambda
                    ExponentialDistribution reference = new ExponentialDistribution(1.0 / lambda);
                    double cellLogProb = reference.logDensity(labels.getDouble(row, col));
                    totalLogProb += cellLogProb;
                    rowLogProb += cellLogProb;
                }
                assertEquals(-perExample.getDouble(row), rowLogProb, 1e-6);
            }
            double expectedNegLogProb = average ? -totalLogProb / mb : -totalLogProb;
            assertEquals(expectedNegLogProb, negLogProb, 1e-6);

            // Sampling: values generated at the mean and at random must both be non-negative
            INDArray grid = Nd4j.linspace(-3, 3, mb * inputSize).reshape(mb, inputSize);
            INDArray meanSamples = dist.generateAtMean(grid);
            INDArray randomSamples = dist.generateRandom(grid);
            for (int row = 0; row < mb; row++) {
                for (int col = 0; col < inputSize; col++) {
                    assertTrue(meanSamples.getDouble(row, col) >= 0.0);
                    assertTrue(randomSamples.getDouble(row, col) >= 0.0);
                }
            }
        }
    }
}
Also used : Random(java.util.Random) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ExponentialReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.ExponentialReconstructionDistribution) ExponentialDistribution(org.apache.commons.math3.distribution.ExponentialDistribution) GaussianReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution) ReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution) ExponentialReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.ExponentialReconstructionDistribution) BernoulliReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution) Test(org.junit.Test)

Example 44 with Mean

Use of org.apache.commons.math3.stat.descriptive.moment.Mean in project recordinality by cscotta.

From class RecordinalityTest, method buildRun.

/**
 * Builds a task that computes {@code numRuns} independent Recordinality estimates over
 * {@code lines} with sketch size {@code kSize} and summarizes them.
 *
 * @param kSize size of each Recordinality sketch
 * @param numRuns number of independent estimates to compute
 * @param lines elements observed by every run
 * @return a task yielding a {@code Result} with the mean estimate, the standard deviation
 *     divided by a hard-coded constant (3193), and the elapsed time in milliseconds
 */
private Callable<Result> buildRun(final int kSize, final int numRuns, final List<String> lines) {
    return new Callable<Result>() {

        @Override
        public Result call() throws Exception {
            // nanoTime is monotonic and meant for elapsed-time measurement;
            // currentTimeMillis can jump if the wall clock is adjusted mid-run.
            final long startNanos = System.nanoTime();
            final double[] results = new double[numRuns];
            for (int i = 0; i < numRuns; i++) {
                Recordinality rec = new Recordinality(kSize);
                for (String line : lines) {
                    rec.observe(line);
                }
                results[i] = rec.estimateCardinality();
            }
            double mean = new Mean().evaluate(results);
            double stdDev = new StandardDeviation().evaluate(results);
            // NOTE(review): 3193 is a hard-coded divisor, presumably the true cardinality of the
            // test data set (making this a relative error rather than a standard error) — confirm.
            double stdError = stdDev / 3193;
            long runTime = (System.nanoTime() - startNanos) / 1000000L;
            return new Result(kSize, mean, stdError, runTime);
        }
    };
}
Also used : Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) StandardDeviation(org.apache.commons.math3.stat.descriptive.moment.StandardDeviation)

Example 45 with Mean

Use of org.apache.commons.math3.stat.descriptive.moment.Mean in project presto by prestodb.

From class AbstractTestQueries, method testTableSampleBernoulli.

/**
 * Verifies that {@code TABLESAMPLE BERNOULLI (50)} never returns duplicate rows and that, over
 * 100 runs, the average fraction of rows sampled is roughly 0.5.
 */
@Test
public void testTableSampleBernoulli() {
    DescriptiveStatistics stats = new DescriptiveStatistics();
    int total = computeExpected("SELECT orderkey FROM orders", ImmutableList.of(BIGINT)).getMaterializedRows().size();
    for (int i = 0; i < 100; i++) {
        List<MaterializedRow> values = computeActual("SELECT orderkey FROM ORDERS TABLESAMPLE BERNOULLI (50)").getMaterializedRows();
        assertEquals(values.size(), ImmutableSet.copyOf(values).size(), "TABLESAMPLE produced duplicate rows");
        // Fraction of the table kept by this sample
        stats.addValue(values.size() * 1.0 / total);
    }
    // Arithmetic mean: the expected per-run sampling rate is 0.5. A geometric mean is biased
    // low relative to it and collapses to 0 if any single run returns an empty sample.
    double mean = stats.getMean();
    assertTrue(mean > 0.45 && mean < 0.55, format("Expected mean sampling rate to be ~0.5, but was %s", mean));
}
Also used : DescriptiveStatistics(org.apache.commons.math3.stat.descriptive.DescriptiveStatistics) MaterializedRow(com.facebook.presto.testing.MaterializedRow) Test(org.testng.annotations.Test)

Aggregations

Test (org.testng.annotations.Test)27 Mean (org.apache.commons.math3.stat.descriptive.moment.Mean)23 List (java.util.List)17 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)16 RealMatrix (org.apache.commons.math3.linear.RealMatrix)14 ArrayList (java.util.ArrayList)12 Collectors (java.util.stream.Collectors)12 StandardDeviation (org.apache.commons.math3.stat.descriptive.moment.StandardDeviation)12 Utils (org.broadinstitute.hellbender.utils.Utils)12 StoredDataStatistics (gdsc.core.utils.StoredDataStatistics)10 Arrays (java.util.Arrays)10 IntStream (java.util.stream.IntStream)10 NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution)10 WeightedObservedPoint (org.apache.commons.math3.fitting.WeightedObservedPoint)10 Logger (org.apache.logging.log4j.Logger)10 ReadCountCollection (org.broadinstitute.hellbender.tools.exome.ReadCountCollection)10 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)10 BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest)10 Function (java.util.function.Function)9 DescriptiveStatistics (org.apache.commons.math3.stat.descriptive.DescriptiveStatistics)9