Search in sources :

Example 11 with PoissonDistribution

use of org.apache.commons.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class PoissonCalculatorTest method canComputeFastLikelihoodForIntegerData.

/**
 * Verifies that the fast Poisson likelihood and log-likelihood from {@code PoissonCalculator}
 * agree with the exact {@link PoissonDistribution} for integer observations 0 to 99 at every
 * mean in {@code photons}.
 *
 * <p>The likelihood is only compared where the exact probability exceeds 1e-100; the
 * log-likelihood is compared for every observation. Both comparisons use
 * {@code TestHelper.doublesAreClose(1e-4, 0)}.
 */
@Test
void canComputeFastLikelihoodForIntegerData() {
    final DoubleDoubleBiPredicate closeEnough = TestHelper.doublesAreClose(1e-4, 0);
    for (final double mean : photons) {
        final PoissonDistribution poisson = new PoissonDistribution(mean);
        for (int count = 0; count < 100; count++) {
            final double exactP = poisson.probability(count);
            final double fastP = PoissonCalculator.fastLikelihood(mean, count);
            // Only compare probabilities that are not vanishingly small.
            if (exactP > 1e-100) {
                TestAssertions.assertTest(exactP, fastP, closeEnough);
            }
            // The log form is always compared.
            TestAssertions.assertTest(poisson.logProbability(count),
                PoissonCalculator.fastLogLikelihood(mean, count), closeEnough);
        }
    }
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Example 12 with PoissonDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class PoissonFunctionTest method probabilityMatchesPoissonWithNoGain.

/**
 * Asserts that a {@code PoissonFunction} with a gain of 1 reproduces the Poisson probability
 * mass function with mean {@code mu} for every integer in the inclusive range returned by
 * {@code getRange(1, mu)}.
 *
 * @param mu the Poisson mean; also used as the expected value passed to the likelihood
 */
private static void probabilityMatchesPoissonWithNoGain(final double mu) {
    // The gain is fixed at 1 for this test. Keep it in one local so the function under test
    // and the failure message always report the same value (and %f gets a double).
    final double noGain = 1.0;
    final double o = mu;
    final PoissonFunction f = new PoissonFunction(noGain);
    final PoissonDistribution pd = new PoissonDistribution(mu);
    final int[] range = getRange(1, mu);
    final int min = range[0];
    final int max = range[1];
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-8, 0);
    for (int x = min; x <= max; x++) {
        final double v1 = f.likelihood(x, o);
        final double v2 = pd.probability(x);
        TestAssertions.assertTest(v1, v2, predicate,
            FunctionUtils.getSupplier("g=%f, mu=%f, x=%d", noGain, mu, x));
    }
}
Also used : PoissonDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)

Example 13 with PoissonDistribution

use of org.apache.commons.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class ScmosLikelihoodWrapperTest method instanceLikelihoodMatches.

/**
 * Checks {@code ScmosLikelihoodWrapper} for a function that predicts the constant value
 * {@code mu} at every data point.
 *
 * <p>The test:
 * <ul>
 *   <li>compares the per-point negative log-likelihood, with and without a gradient array,
 *       against the static {@code ScmosLikelihoodWrapper.negativeLogLikelihood};</li>
 *   <li>checks that the maximum of {@code exp(-nll)} lies at the scaled Poisson mode;</li>
 *   <li>optionally checks the integrated likelihood;</li>
 *   <li>checks that the summed likelihood matches the per-point total, both before and after
 *       {@code build}.</li>
 * </ul>
 *
 * <p>Observed values {@code k} are created with {@code SimpleArrayUtils.newArray(n, O, step)}.
 * Presumably these are {@code O + i * step}, spanning roughly the offset up to
 * {@code O + limit * G}, where {@code limit} is the Poisson quantile at {@code P_LIMIT}.
 *
 * @param mu the constant value returned by the test function for every data point
 * @param test if true, also assert that the integrated likelihood is within 0.02 of
 *     {@code P_LIMIT}
 */
private static void instanceLikelihoodMatches(final double mu, boolean test) {
    // Determine upper limit for a Poisson
    final int limit = new PoissonDistribution(mu).inverseCumulativeProbability(P_LIMIT);
    // Map to observed values using the gain and offset
    final double max = limit * G;
    final double step = 0.1;
    final int n = (int) Math.ceil(max / step);
    // Evaluate all values from (zero+offset) to large n
    final double[] k = SimpleArrayUtils.newArray(n, O, step);
    // No fit parameters and therefore no gradients.
    final double[] a = new double[0];
    final double[] gradient = new double[0];
    // Per-pixel sCMOS variance, gain and offset arrays, all set to the same constant.
    final float[] var = newArray(n, VAR);
    final float[] g = newArray(n, G);
    final float[] o = newArray(n, O);
    // A trivial function: every evaluation returns mu and there are no gradients or weights.
    final NonLinearFunction nlf = new NonLinearFunction() {

        @Override
        public void initialise(double[] a) {
        // Ignore
        }

        @Override
        public int[] gradientIndices() {
            return new int[0];
        }

        @Override
        public double evalw(int x, double[] dyda, double[] weight) {
            return 0;
        }

        @Override
        public double evalw(int x, double[] weight) {
            return 0;
        }

        @Override
        public double eval(int x) {
            return mu;
        }

        @Override
        public double eval(int x, double[] dyda) {
            return mu;
        }

        @Override
        public boolean canComputeWeights() {
            return false;
        }

        @Override
        public int getNumberOfGradients() {
            return 0;
        }
    };
    ScmosLikelihoodWrapper func = new ScmosLikelihoodWrapper(nlf, a, k, n, var, g, o);
    // Reusable failure-message suppliers; the index is updated via set(0, i) in the loop.
    final IntArrayFormatSupplier msg1 = new IntArrayFormatSupplier("computeLikelihood @ %d", 1);
    final IntArrayFormatSupplier msg2 = new IntArrayFormatSupplier("computeLikelihood+gradient @ %d", 1);
    // Sum of per-point negative log-likelihoods, compared with the whole-data total below.
    double total = 0;
    // Rectangle-rule approximation of the integral of exp(-nll) over the observed values.
    double pvalue = 0;
    // Largest exp(-nll) seen and the index at which it occurred.
    double maxp = 0;
    int maxi = 0;
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-10, 0);
    for (int i = 0; i < n; i++) {
        // The per-point value must be the same with and without a gradient array, and must
        // match the static reference formula.
        final double nll = func.computeLikelihood(i);
        final double nll2 = func.computeLikelihood(gradient, i);
        final double nll3 = ScmosLikelihoodWrapper.negativeLogLikelihood(mu, var[i], g[i], o[i], k[i]);
        total += nll;
        TestAssertions.assertTest(nll3, nll, predicate, msg1.set(0, i));
        TestAssertions.assertTest(nll3, nll2, predicate, msg2.set(0, i));
        final double pp = StdMath.exp(-nll);
        if (maxp < pp) {
            maxp = pp;
            maxi = i;
        // TestLog.fine(logger,"mu=%f, e=%f, k=%f, pp=%f", mu, mu * G + O, k[i], pp);
        }
        pvalue += pp * step;
    }
    // Expected max of the distribution is the mode of the Poisson distribution.
    // When lambda is an integer there are two modes (lambda and lambda - 1) and we take the
    // mean of those; otherwise mode1 == mode2 == floor(lambda).
    // https://en.wikipedia.org/wiki/Poisson_distribution
    // Note that the shift of VAR/(G*G) is a constant applied to both the expected and
    // observed values and consequently cancels when predicting the max, i.e. we add
    // a constant count to the observed values and shift the distribution by the same
    // constant. We can thus compute the mode for the unshifted distribution.
    final double lambda = mu;
    final double mode1 = Math.floor(lambda);
    final double mode2 = Math.ceil(lambda) - 1;
    // Scale to observed values
    final double kmax = ((mode1 + mode2) * 0.5) * G + O;
    // TestLog.fine(logger,"mu=%f, p=%f, maxp=%f @ %f (expected=%f %f)", mu, p, maxp, k[maxi], kmax,
    // kmax - k[maxi]);
    TestAssertions.assertTest(kmax, k[maxi], TestHelper.doublesAreClose(1e-3, 0), "k-max");
    if (test) {
        // The integrated likelihood should cover roughly P_LIMIT of the probability mass.
        Assertions.assertEquals(P_LIMIT, pvalue, 0.02, () -> "mu=" + mu);
    }
    // Check the function can compute the same total
    double sum;
    double sum2;
    sum = func.computeLikelihood();
    sum2 = func.computeLikelihood(gradient);
    TestAssertions.assertTest(total, sum, predicate, "computeLikelihood");
    TestAssertions.assertTest(total, sum2, predicate, "computeLikelihood with gradient");
    // Check the function can compute the same total after duplication
    func = func.build(nlf, a);
    sum = func.computeLikelihood();
    sum2 = func.computeLikelihood(gradient);
    TestAssertions.assertTest(total, sum, predicate, "computeLikelihood");
    TestAssertions.assertTest(total, sum2, predicate, "computeLikelihood with gradient");
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) IntArrayFormatSupplier(uk.ac.sussex.gdsc.test.utils.functions.IntArrayFormatSupplier)

Example 14 with PoissonDistribution

use of org.apache.commons.math3.distribution.PoissonDistribution in project shifu by ShifuML.

the class LogisticRegressionWorker method init.

/**
 * Initialises the worker from the Guagua worker context.
 *
 * <p>This method:
 * <ul>
 *   <li>loads the configuration files;</li>
 *   <li>derives the input, output and candidate counts;</li>
 *   <li>configures bagging and optional up-sampling generators;</li>
 *   <li>allocates the memory/disk-backed training and validation data lists;</li>
 *   <li>registers a shutdown hook that closes both lists.</li>
 * </ul>
 *
 * @param context the worker context whose properties supply runtime configuration
 * @throws IllegalStateException if no input variables are selected
 */
@Override
public void init(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    loadConfigFiles(context.getProps());
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    // Fall back to the candidate count when the first count is zero.
    this.inputNum = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    this.outputNum = inputOutputIndex[1];
    this.candidateNum = inputOutputIndex[2];
    this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath()));
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.trainerId = Integer.parseInt(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
    }
    if (this.inputNum == 0) {
        throw new IllegalStateException("No any variables are selected, please try variable select step firstly.");
    }
    this.rng = new PoissonDistribution(1.0d);
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    // Treat an unset (null) weight as "no up sampling" instead of unboxing null into an NPE,
    // matching the null check applied to the boxed k-fold setting above.
    if (upSampleWeight != null && Double.compare(upSampleWeight, 1d) != 0) {
        // set mean to upSampleWeight -1 and get sample + 1 to make sure no zero sample value
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    double memoryFraction = Double.parseDouble(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    long maxMemory = Runtime.getRuntime().maxMemory();
    LOG.info("Max heap memory: {}, fraction: {}", maxMemory, memoryFraction);
    double crossValidationRate = this.modelConfig.getValidSetRate();
    String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", "tmp");
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // fixed 0.6 and 0.4 of max memory for trainingData and validationData
        this.trainingData = new BytableMemoryDiskList<Data>((long) (maxMemory * memoryFraction * 0.6), tmpFolder + File.separator + "train-" + System.currentTimeMillis(), Data.class.getName());
        this.validationData = new BytableMemoryDiskList<Data>((long) (maxMemory * memoryFraction * 0.4), tmpFolder + File.separator + "test-" + System.currentTimeMillis(), Data.class.getName());
    } else {
        // split memory between training and validation by the validation set rate
        this.trainingData = new BytableMemoryDiskList<Data>((long) (maxMemory * memoryFraction * (1 - crossValidationRate)), tmpFolder + File.separator + "train-" + System.currentTimeMillis(), Data.class.getName());
        this.validationData = new BytableMemoryDiskList<Data>((long) (maxMemory * memoryFraction * crossValidationRate), tmpFolder + File.separator + "test-" + System.currentTimeMillis(), Data.class.getName());
    }
    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
    // cannot find a good place to close these two data set, using Shutdown hook
    Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {

        @Override
        public void run() {
            // Close the training data even if closing the validation data fails.
            try {
                LogisticRegressionWorker.this.validationData.close();
            } finally {
                LogisticRegressionWorker.this.trainingData.close();
            }
        }
    }));
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution)

Example 15 with PoissonDistribution

use of org.apache.commons.math3.distribution.PoissonDistribution in project shifu by ShifuML.

the class LogisticRegressionWorker method sampleWeights.

/**
 * Draws the bagging weight for one training record.
 *
 * <p>Without replacement, the weight is 1 when a uniform draw from a cached {@link Random} is
 * at most the sample rate, and 0 otherwise. With replacement, the weight is a draw from a
 * cached {@code PoissonDistribution} whose mean is the sample rate.
 *
 * <p>With stratified sampling the generators are cached per class value. Otherwise a single
 * generator stored under key 0 is shared by all records.
 *
 * @param label the record's target label; 0.01 is added and the result cast to an int class value
 * @return the sample weight for the record
 */
protected float sampleWeights(float label) {
    // Negative-only sampling and k-fold CV use a sample rate of 1.
    final double sampleRate = (modelConfig.getTrain().getSampleNegOnly() || this.isKFoldCV)
            ? 1d
            : modelConfig.getTrain().getBaggingSampleRate();
    final int classValue = (int) (label + 0.01f);
    // One generator per class when stratified, otherwise a single shared generator at key 0.
    final int generatorKey = this.isStratifiedSampling ? classValue : 0;

    if (modelConfig.isBaggingWithReplacement()) {
        PoissonDistribution poisson = this.baggingRngMap.get(generatorKey);
        if (poisson == null) {
            poisson = new PoissonDistribution(sampleRate);
            this.baggingRngMap.put(generatorKey, poisson);
        }
        return poisson.sample();
    }

    Random uniform = baggingRandomMap.get(generatorKey);
    if (uniform == null) {
        uniform = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(),
                CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
        baggingRandomMap.put(generatorKey, uniform);
    }
    return uniform.nextDouble() <= sampleRate ? 1f : 0f;
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) Random(java.util.Random)

Aggregations

PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution)39 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)7 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)5 Test (org.junit.jupiter.api.Test)4 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 Random (java.util.Random)3 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)3 Test (org.junit.Test)3 SqlScalarFunction (com.facebook.presto.metadata.SqlScalarFunction)2 Description (com.facebook.presto.spi.function.Description)2 ScalarFunction (com.facebook.presto.spi.function.ScalarFunction)2 SqlType (com.facebook.presto.spi.function.SqlType)2 TDigest.createTDigest (com.facebook.presto.tdigest.TDigest.createTDigest)2 DecimalOperators.modulusScalarFunction (com.facebook.presto.type.DecimalOperators.modulusScalarFunction)2 Sets (com.google.cloud.dataflow.sdk.repackaged.com.google.common.collect.Sets)2 java.util (java.util)2 GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException)2 GridSearch (ml.shifu.shifu.core.dtrain.gs.GridSearch)2 PoissonDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution)2