Search in sources :

Example 16 with PoissonDistribution

Use of org.apache.commons.math3.distribution.PoissonDistribution in project shifu by ShifuML.

the class WDLWorker method init.

/**
 * Initialises this worker from the job context: loads the model and column configuration,
 * prepares the record splitter and samplers, sizes the in-memory training/validation lists and
 * builds the wide-and-deep network.
 *
 * @param context worker context whose properties supply the config locations and tuning options
 * @throws RuntimeException wrapping the {@link IOException} raised when a config cannot be loaded
 */
@SuppressWarnings({ "unchecked", "unused" })
@Override
public void init(WorkerContext<WDLParams, WDLParams> context) {
    Properties props = context.getProps();
    try {
        SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    this.initCateIndexMap();
    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);

    // create Splitter
    String delimiter = props.getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);

    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }

    // parseBoolean is a case-insensitive match against "true" and yields false for a missing property
    this.poissonSampler = Boolean.parseBoolean(props.getProperty(NNConstants.NN_POISON_SAMPLER));
    this.rng = new PoissonDistribution(1d);

    // The weight is a boxed Double: guard against null before the numeric comparison, otherwise
    // auto-unboxing throws a NullPointerException. A null weight is treated as "no up sampling".
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    if (upSampleWeight != null && upSampleWeight != 1d && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set mean to upSampleWeight - 1 and get sample + 1 to make sure no zero sample value
        // NOTE(review): a weight below 1 gives a non-positive Poisson mean - confirm it is validated upstream
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }

    this.trainerId = Integer.valueOf(props.getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));

    // Share of the max heap that the in-memory data lists may occupy
    double memoryFraction = Double.parseDouble(props.getProperty("guagua.data.memoryFraction", "0.6"));
    long maxMemory = Runtime.getRuntime().maxMemory();
    LOG.info("Max heap memory: {}, fraction: {}", maxMemory, memoryFraction);

    double validationRate = this.modelConfig.getValidSetRate();
    // One definition of "a validation set path was supplied" drives both the memory split and
    // the isManualValidation flag, so a whitespace-only path cannot make them disagree.
    boolean hasValidationPath = StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath());
    if (hasValidationPath) {
        // fixed 0.6 and 0.4 of max memory for trainingData and validationData
        this.trainingData = new MemoryLimitedList<Data>((long) (maxMemory * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>((long) (maxMemory * memoryFraction * 0.4), new ArrayList<Data>());
    } else if (validationRate != 0d) {
        // split the budget in proportion to the configured validation rate
        this.trainingData = new MemoryLimitedList<Data>((long) (maxMemory * memoryFraction * (1 - validationRate)), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>((long) (maxMemory * memoryFraction * validationRate), new ArrayList<Data>());
    } else {
        // no validation data: the whole budget goes to training data (validationData is not set here)
        this.trainingData = new MemoryLimitedList<Data>((long) (maxMemory * memoryFraction), new ArrayList<Data>());
    }

    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all input
    this.numInputs = inputOutputIndex[0];
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is
    // 1, with index of 0,1,2,3 denotes different classes
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    this.isManualValidation = hasValidationPath;
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.validParams = this.modelConfig.getTrain().getParams();

    // Build wide and deep graph
    List<Integer> embedColumnIds = (List<Integer>) this.validParams.get(CommonConstants.NUM_EMBED_COLUMN_IDS);
    Integer embedOutputs = (Integer) this.validParams.get(CommonConstants.NUM_EMBED_OUTPUTS);
    // Every embedding column gets the same output size: the configured value or the default
    Integer embedOutputSize = (embedOutputs == null) ? CommonConstants.DEFAULT_EMBEDING_OUTPUT : embedOutputs;
    List<Integer> embedOutputList = new ArrayList<Integer>();
    for (int i = 0; i < embedColumnIds.size(); i++) {
        embedOutputList.add(embedOutputSize);
    }
    List<Integer> numericalIds = DTrainUtils.getNumericalIds(this.columnConfigList, isAfterVarSelect);
    List<Integer> wideColumnIds = DTrainUtils.getCategoricalIds(columnConfigList, isAfterVarSelect);
    Map<Integer, Integer> idBinCateSizeMap = DTrainUtils.getIdBinCategorySizeMap(columnConfigList);
    // Not passed to the network below; the unboxing still fails fast when the parameter is absent
    int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodes = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    // Cast to Number rather than Double so an Integer or Float valued parameter does not raise
    // a ClassCastException
    Float l2reg = ((Number) this.validParams.get(CommonConstants.WDL_L2_REG)).floatValue();
    this.wnd = new WideAndDeep(idBinCateSizeMap, numInputs, numericalIds, embedColumnIds, embedOutputList, wideColumnIds, hiddenNodes, actFunc, l2reg);
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) IOException(java.io.IOException) MemoryLimitedList(ml.shifu.guagua.util.MemoryLimitedList)

Example 17 with PoissonDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class CreateData method run.

/**
 * Runs the data-creation plugin. The {@code arg} string selects the mode:
 * <ul>
 * <li>{@code "load"} loads benchmark data and returns immediately;</li>
 * <li>an argument containing {@code "simple"}, {@code "benchmark"} or {@code "spot"} builds
 * one-frame localisations directly from a spatial distribution;</li>
 * <li>otherwise an image model is used, in track mode when the argument contains
 * {@code "track"} and as the full simulation when it does not.</li>
 * </ul>
 * The localisations are then drawn, summarised and saved. The method returns early, without
 * saving, when a dialog is not confirmed or when no fluorophores/localisations are produced.
 *
 * @param arg the plugin argument used to select the mode (may be null)
 * @see ij.plugin.PlugIn#run(java.lang.String)
 */
public void run(String arg) {
    SMLMUsageTracker.recordPlugin(this.getClass(), arg);
    extraOptions = Utils.isExtraOptions();
    // Mode flags are substring matches on the argument; a null argument enables none of them
    simpleMode = (arg != null && arg.contains("simple"));
    benchmarkMode = (arg != null && arg.contains("benchmark"));
    spotMode = (arg != null && arg.contains("spot"));
    trackMode = (arg != null && arg.contains("track"));
    if ("load".equals(arg)) {
        loadBenchmarkData();
        return;
    }
    // Each localisation is a simulated emission of light from a point in space and time
    List<LocalisationModel> localisations = null;
    // Each localisation set is a collection of localisations that represent all localisations
    // with the same ID that are on in the same image time frame (Note: the simulation
    // can create many localisations per fluorophore per time frame which is useful when 
    // modelling moving particles)
    List<LocalisationModelSet> localisationSets = null;
    // Each fluorophore contains the on and off times when light was emitted 
    List<? extends FluorophoreSequenceModel> fluorophores = null;
    if (simpleMode || benchmarkMode || spotMode) {
        if (!showSimpleDialog())
            return;
        resetMemory();
        // 1 second frames
        settings.exposureTime = 1000;
        // Square image of side (size * pixelPitch); the 1e6 divisor presumably converts nm^2 to um^2 - TODO confirm units
        areaInUm = settings.size * settings.pixelPitch * settings.size * settings.pixelPitch / 1e6;
        // Number of spots per frame
        int n = 0;
        // Optional per-frame counts; when non-null these override n for each frame
        int[] nextN = null;
        SpatialDistribution dist;
        if (benchmarkMode) {
            // --------------------
            // BENCHMARK SIMULATION
            // --------------------
            // Draw the same point on the image repeatedly
            n = 1;
            dist = createFixedDistribution();
            reportAndSaveFittingLimits(dist);
        } else if (spotMode) {
            // ---------------
            // SPOT SIMULATION
            // ---------------
            // The spot simulation draws 0 or 1 random point per frame. 
            // Ensure we have 50% of the frames with a spot.
            nextN = new int[settings.particles * 2];
            // First half of the array is set to 1, the rest stays 0; the shuffle then randomises the order
            Arrays.fill(nextN, 0, settings.particles, 1);
            Random rand = new Random();
            rand.shuffle(nextN);
            // Only put spots in the central part of the image
            double border = settings.size / 4.0;
            dist = createUniformDistribution(border);
        } else {
            // -----------------
            // SIMPLE SIMULATION
            // -----------------
            // The simple simulation draws n random points per frame to achieve a specified density.
            // No points will appear in multiple frames.
            // Each point has a random number of photons sampled from a range.
            // We can optionally use a mask. Create this first as it updates the areaInUm
            dist = createDistribution();
            // Randomly sample (i.e. not uniform density in all frames)
            if (settings.samplePerFrame) {
                final double mean = areaInUm * settings.density;
                System.out.printf("Mean samples = %f\n", mean);
                // Ask for confirmation when the expected count per frame is below 0.5
                if (mean < 0.5) {
                    GenericDialog gd = new GenericDialog(TITLE);
                    gd.addMessage("The mean samples per frame is low: " + Utils.rounded(mean) + "\n \nContinue?");
                    gd.enableYesNoCancel();
                    gd.hideCancelButton();
                    gd.showDialog();
                    if (!gd.wasOKed())
                        return;
                }
                PoissonDistribution poisson = new PoissonDistribution(createRandomGenerator(), mean, PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
                StoredDataStatistics samples = new StoredDataStatistics(settings.particles);
                // Keep drawing per-frame counts until their total reaches the requested particle count
                while (samples.getSum() < settings.particles) {
                    samples.add(poisson.sample());
                }
                // One entry per sampled frame
                nextN = new int[samples.getN()];
                for (int i = 0; i < nextN.length; i++) nextN[i] = (int) samples.getValue(i);
            } else {
                // Use the density to get the number per frame
                n = (int) FastMath.max(1, Math.round(areaInUm * settings.density));
            }
        }
        RandomGenerator random = null;
        localisations = new ArrayList<LocalisationModel>(settings.particles);
        localisationSets = new ArrayList<LocalisationModelSet>(settings.particles);
        final int minPhotons = (int) settings.photonsPerSecond;
        // Inclusive size of the photon range; a generator is only created when there is more than one value
        final int range = (int) settings.photonsPerSecondMaximum - minPhotons + 1;
        if (range > 1)
            random = createRandomGenerator();
        // Add frames at the specified density until the number of particles has been reached
        int id = 0;
        int t = 0;
        while (id < settings.particles) {
            // Allow the number per frame to be specified
            if (nextN != null) {
                // Stop when the per-frame counts are exhausted, even if fewer particles were created
                if (t >= nextN.length)
                    break;
                n = nextN[t];
            }
            // Simulate random positions in the frame for the specified density
            // (t is incremented before use, so the first frame is numbered 1)
            t++;
            for (int j = 0; j < n; j++) {
                final double[] xyz = dist.next();
                // Ignore within border. We do not want to draw things we cannot fit.
                //if (!distBorder.isWithinXY(xyz))
                //	continue;
                // Simulate random photons
                final int intensity = minPhotons + ((random != null) ? random.nextInt(range) : 0);
                LocalisationModel m = new LocalisationModel(id, t, xyz, intensity, LocalisationModel.CONTINUOUS);
                localisations.add(m);
                // Each localisation can be a separate localisation set
                LocalisationModelSet set = new LocalisationModelSet(id, t);
                set.add(m);
                localisationSets.add(set);
                id++;
            }
        }
    } else {
        if (!showDialog())
            return;
        resetMemory();
        // Same area computation as the simple branch - TODO confirm units
        areaInUm = settings.size * settings.pixelPitch * settings.size * settings.pixelPitch / 1e6;
        int totalSteps;
        double correlation = 0;
        ImageModel imageModel;
        if (trackMode) {
            // ----------------
            // TRACK SIMULATION
            // ----------------
            // In track mode we create fixed lifetime fluorophores that do not overlap in time.
            // This is the simplest simulation to test moving molecules.
            settings.seconds = (int) Math.ceil(settings.particles * (settings.exposureTime + settings.tOn) / 1000);
            // Zero here; the value is later replaced by the result of checkTotalSteps
            totalSteps = 0;
            final double simulationStepsPerFrame = (settings.stepsPerSecond * settings.exposureTime) / 1000.0;
            imageModel = new FixedLifetimeImageModel(settings.stepsPerSecond * settings.tOn / 1000.0, simulationStepsPerFrame);
        } else {
            // ---------------
            // FULL SIMULATION
            // ---------------
            // The full simulation draws n random points in space.
            // The same molecule may appear in multiple frames, move and blink.
            //
            // Points are modelled as fluorophores that must be activated and then will 
            // blink and photo-bleach. The molecules may diffuse and this can be simulated 
            // with many steps per image frame. All steps from a frame are collected
            // into a localisation set which can be drawn on the output image.
            SpatialIllumination activationIllumination = createIllumination(settings.pulseRatio, settings.pulseInterval);
            // Generate additional frames so that each frame has the set number of simulation steps
            totalSteps = (int) Math.ceil(settings.seconds * settings.stepsPerSecond);
            // Since we have an exponential decay of activations
            // ensure half of the particles have activated by 30% of the frames.
            double eAct = totalSteps * 0.3 * activationIllumination.getAveragePhotons();
            // Q. Does tOn/tOff change depending on the illumination strength?
            imageModel = new ActivationEnergyImageModel(eAct, activationIllumination, settings.stepsPerSecond * settings.tOn / 1000.0, settings.stepsPerSecond * settings.tOffShort / 1000.0, settings.stepsPerSecond * settings.tOffLong / 1000.0, settings.nBlinksShort, settings.nBlinksLong);
            imageModel.setUseGeometricDistribution(settings.nBlinksGeometricDistribution);
            // Only use the correlation if selected for the distribution
            if (PHOTON_DISTRIBUTION[PHOTON_CORRELATED].equals(settings.photonDistribution))
                correlation = settings.correlation;
        }
        // Configuration shared by both image models
        imageModel.setRandomGenerator(createRandomGenerator());
        imageModel.setPhotonBudgetPerFrame(true);
        imageModel.setDiffusion2D(settings.diffuse2D);
        imageModel.setRotation2D(settings.rotate2D);
        IJ.showStatus("Creating molecules ...");
        SpatialDistribution distribution = createDistribution();
        List<CompoundMoleculeModel> compounds = createCompoundMolecules();
        if (compounds == null)
            return;
        List<CompoundMoleculeModel> molecules = imageModel.createMolecules(compounds, settings.particles, distribution, settings.rotateInitialOrientation);
        // Activate fluorophores
        IJ.showStatus("Creating fluorophores ...");
        // Note: molecules list will be converted to compounds containing fluorophores
        fluorophores = imageModel.createFluorophores(molecules, totalSteps);
        if (fluorophores.isEmpty()) {
            IJ.error(TITLE, "No fluorophores created");
            return;
        }
        IJ.showStatus("Creating localisations ...");
        // TODO - Output a molecule Id for each fluorophore if using compound molecules. This allows analysis
        // of the ratio of trimers, dimers, monomers, etc that could be detected.
        totalSteps = checkTotalSteps(totalSteps, fluorophores);
        // A result of zero aborts the run
        if (totalSteps == 0)
            return;
        imageModel.setPhotonDistribution(createPhotonDistribution());
        imageModel.setConfinementDistribution(createConfinementDistribution());
        // This should be optimised
        imageModel.setConfinementAttempts(10);
        // Photons are supplied per simulation step (photons per second divided by steps per second)
        localisations = imageModel.createImage(molecules, settings.fixedFraction, totalSteps, (double) settings.photonsPerSecond / settings.stepsPerSecond, correlation, settings.rotateDuringSimulation);
        // Re-adjust the fluorophores to the correct time
        if (settings.stepsPerSecond != 1) {
            final double scale = 1.0 / settings.stepsPerSecond;
            for (FluorophoreSequenceModel f : fluorophores) f.adjustTime(scale);
        }
        // Integrate the frames
        localisationSets = combineSimulationSteps(localisations);
        localisationSets = filterToImageBounds(localisationSets);
    }
    datasetNumber++;
    // Drawing returns the list of localisations used from here on, replacing the simulated list
    localisations = drawImage(localisationSets);
    if (localisations == null || localisations.isEmpty()) {
        IJ.error(TITLE, "No localisations created");
        return;
    }
    fluorophores = removeFilteredFluorophores(fluorophores, localisations);
    double signalPerFrame = showSummary(fluorophores, localisations);
    // Simulation parameters are saved for every mode except benchmark mode
    if (!benchmarkMode) {
        boolean fullSimulation = (!(simpleMode || spotMode));
        saveSimulationParameters(localisations.size(), fullSimulation, signalPerFrame);
    }
    IJ.showStatus("Saving data ...");
    //convertRelativeToAbsolute(molecules);
    saveFluorophores(fluorophores);
    saveImageResults(results);
    saveLocalisations(localisations);
    // The settings for the filenames may have changed
    SettingsManager.saveSettings(globalSettings);
    IJ.showStatus("Done");
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) ActivationEnergyImageModel(gdsc.smlm.model.ActivationEnergyImageModel) CompoundMoleculeModel(gdsc.smlm.model.CompoundMoleculeModel) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Random(gdsc.core.utils.Random) GenericDialog(ij.gui.GenericDialog) SpatialIllumination(gdsc.smlm.model.SpatialIllumination) SpatialDistribution(gdsc.smlm.model.SpatialDistribution) StoredDataStatistics(gdsc.core.utils.StoredDataStatistics) LocalisationModel(gdsc.smlm.model.LocalisationModel) FluorophoreSequenceModel(gdsc.smlm.model.FluorophoreSequenceModel) FixedLifetimeImageModel(gdsc.smlm.model.FixedLifetimeImageModel) LocalisationModelSet(gdsc.smlm.model.LocalisationModelSet) ImageModel(gdsc.smlm.model.ImageModel) ActivationEnergyImageModel(gdsc.smlm.model.ActivationEnergyImageModel) FixedLifetimeImageModel(gdsc.smlm.model.FixedLifetimeImageModel)

Example 18 with PoissonDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class GradientCalculatorSpeedTest method mleGradientCalculatorComputesLikelihood.

/**
 * Checks the likelihood and log-likelihood computed by {@code MLEGradientCalculator} against the
 * probability and log-probability reported by a {@code PoissonDistribution} with the same mean,
 * both per value and summed over all values.
 */
@Test
public void mleGradientCalculatorComputesLikelihood() {
    // Stub function whose value at every x is the first parameter it was initialised with
    //@formatter:off
    final NonLinearFunction constantFunction = new NonLinearFunction() {

        double mean;

        public void initialise(double[] a) {
            mean = a[0];
        }

        public double eval(int x) {
            return mean;
        }

        public double eval(int x, double[] dyda) {
            return 0;
        }

        public double eval(int x, double[] dyda, double[] w) {
            return 0;
        }

        public double evalw(int x, double[] w) {
            return 0;
        }

        public int[] gradientIndices() {
            return null;
        }

        public int getNumberOfGradients() {
            return 0;
        }

        public boolean canComputeWeights() {
            return false;
        }
    };
    //@formatter:on
    final int[] counts = Utils.newArray(100, 0, 1);
    final double[] observations = Utils.newArray(100, 0, 1.0);
    final double[] means = { 0.79, 2.5, 5.32 };
    for (int i = 0; i < means.length; i++) {
        final double mean = means[i];
        final PoissonDistribution poisson = new PoissonDistribution(mean);
        double expectedSum = 0;
        for (int j = 0; j < counts.length; j++) {
            final int count = counts[j];
            // Single-value likelihood, with a tolerance relative to the expected value
            final double observedP = MLEGradientCalculator.likelihood(mean, count);
            final double expectedP = poisson.probability(count);
            Assert.assertEquals("likelihood", expectedP, observedP, expectedP * 1e-10);
            // Single-value log-likelihood
            final double observedLogP = MLEGradientCalculator.logLikelihood(mean, count);
            final double expectedLogP = poisson.logProbability(count);
            Assert.assertEquals("log likelihood", expectedLogP, observedLogP, Math.abs(expectedLogP) * 1e-10);
            expectedSum += expectedLogP;
        }
        // Summed log-likelihood over all the values using the stub function
        final MLEGradientCalculator calculator = new MLEGradientCalculator(1);
        final double observedSum = calculator.logLikelihood(observations, new double[] { mean }, constantFunction);
        Assert.assertEquals("sum log likelihood", expectedSum, observedSum, Math.abs(expectedSum) * 1e-10);
    }
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) NonLinearFunction(gdsc.smlm.function.NonLinearFunction) Test(org.junit.Test)

Example 19 with PoissonDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class PoissonCalculatorTest method cumulativeProbabilityIsOneWithRealDataForCountAbove4.

/**
 * For each photon mean, derives an integration range from a Poisson distribution and delegates
 * to {@code cumulativeProbabilityIsOneWithRealData}, passing whether the mean is at least 4.
 */
@Test
public void cumulativeProbabilityIsOneWithRealDataForCountAbove4() {
    for (final double mean : photons) {
        // Upper limit: the Poisson quantile at P_LIMIT
        final PoissonDistribution poisson = new PoissonDistribution(mean);
        final double upper = poisson.inverseCumulativeProbability(P_LIMIT);
        // Lower limit: four standard deviations below the mean, clipped at zero and truncated
        final double lower = (int) Math.max(0, mean - 4 * Math.sqrt(mean));
        final boolean atLeastFour = mean >= 4;
        cumulativeProbabilityIsOneWithRealData(mean, lower, upper, atLeastFour);
    }
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) Test(org.junit.Test)

Example 20 with PoissonDistribution

use of uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution in project GDSC-SMLM by aherbert.

the class SCMOSLikelihoodWrapperTest method cumulativeProbabilityIsOneWithRealDataForCountAbove8.

/**
 * For each photon mean, derives an integration range from a Poisson distribution, maps it with
 * {@code G} and {@code O}, and delegates to {@code cumulativeProbabilityIsOneWithRealData},
 * passing whether the mean is greater than 8.
 */
@Test
public void cumulativeProbabilityIsOneWithRealDataForCountAbove8() {
    for (final double mean : photons) {
        // Upper limit: the Poisson quantile at P_LIMIT
        final PoissonDistribution poisson = new PoissonDistribution(mean);
        final double upperCount = poisson.inverseCumulativeProbability(P_LIMIT);
        // Lower limit: four standard deviations below the mean, clipped at zero and truncated
        final double lowerCount = (int) Math.max(0, mean - 4 * Math.sqrt(mean));
        // Map to observed values using the gain and offset
        final double upperObserved = upperCount * G + O;
        final double lowerObserved = lowerCount * G + O;
        final boolean aboveEight = mean > 8;
        cumulativeProbabilityIsOneWithRealData(mean, lowerObserved, upperObserved, aboveEight);
    }
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) Test(org.junit.Test)

Aggregations

PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution)39 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)7 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)5 Test (org.junit.jupiter.api.Test)4 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 Random (java.util.Random)3 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)3 Test (org.junit.Test)3 SqlScalarFunction (com.facebook.presto.metadata.SqlScalarFunction)2 Description (com.facebook.presto.spi.function.Description)2 ScalarFunction (com.facebook.presto.spi.function.ScalarFunction)2 SqlType (com.facebook.presto.spi.function.SqlType)2 TDigest.createTDigest (com.facebook.presto.tdigest.TDigest.createTDigest)2 DecimalOperators.modulusScalarFunction (com.facebook.presto.type.DecimalOperators.modulusScalarFunction)2 Sets (com.google.cloud.dataflow.sdk.repackaged.com.google.common.collect.Sets)2 java.util (java.util)2 GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException)2 GridSearch (ml.shifu.shifu.core.dtrain.gs.GridSearch)2 PoissonDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.PoissonDistribution)2