Use of org.apache.commons.math3.util.FastMath.PI in project GDSC-SMLM by aherbert: class PoissonGammaGaussianFunction, method likelihood.
/**
* Compute the likelihood of the observed count given the expected count.
*
* @param o
* The observed count
* @param e
* The expected count
* @return The likelihood
*/
public double likelihood(final double o, final double e) {
// Use the same variables as the Python code
final double cij = o;
// convert to photons
final double eta = alpha * e;
if (sigma == 0) {
// No convolution with a Gaussian. Simply evaluate for a Poisson-Gamma distribution.
final double p;
// Any observed count above zero
if (cij > 0.0) {
// The observed count converted to photons
final double nij = alpha * cij;
// exp(x) overflows for x above ~709, and the Bessel argument is 2 * sqrt(eta * nij).
// The limit on eta * nij is therefore (709/2)^2 = 125670.25
if (eta * nij > 10000) {
// Approximate the Bessel function I1(x) for large x:
// I1(x) ~ exp(x) / sqrt(2*pi*x)
// However, the entire expression is evaluated in log space (the 'transform')
// and only then exponentiated, which avoids overflow from the large exp(x).
final double transform = 0.5 * Math.log(alpha * eta / cij) - nij - eta + 2 * Math.sqrt(eta * nij) - Math.log(twoSqrtPi * Math.pow(eta * nij, 0.25));
p = FastMath.exp(transform);
} else {
// Second part of equation 135
p = Math.sqrt(alpha * eta / cij) * FastMath.exp(-nij - eta) * Bessel.I1(2 * Math.sqrt(eta * nij));
}
} else if (cij == 0.0) {
p = FastMath.exp(-eta);
} else {
p = 0;
}
return (p > minimumProbability) ? p : minimumProbability;
} else if (useApproximation) {
return mortensenApproximation(cij, eta);
} else {
// This code is the full evaluation of equation 7 from the supplementary information
// of the paper Chao, et al (2013) Nature Methods 10, 335-338.
// It is the full evaluation of a Poisson-Gamma-Gaussian convolution PMF.
// Read noise
final double sk = sigma;
// Gain
final double g = 1.0 / alpha;
// Observed pixel value
final double z = o;
// Expected number of photons
final double vk = eta;
// Compute the integral to infinity of:
// exp( -((z-u)/(sqrt(2)*s)).^2 - u/g ) * sqrt(vk*u/g) .* besseli(1, 2 * sqrt(vk*u/g)) ./ u;
// vk / g
final double vk_g = vk * alpha;
final double sqrt2sigma = Math.sqrt(2) * sk;
// Specify the function to integrate
UnivariateFunction f = new UnivariateFunction() {
public double value(double u) {
return eval(sqrt2sigma, z, vk_g, g, u);
}
};
// Integrating to infinity is not necessary. The convolution of the function with the
// Gaussian should be adequately sampled using a range of n standard deviations around the maximum.
// Find a bracket containing the maximum
double lower, upper;
double maxU = Math.max(1, o);
double rLower = maxU;
double rUpper = maxU + 1;
double f1 = f.value(rLower);
double f2 = f.value(rUpper);
// Calculate the simple integral and the range
double sum = f1 + f2;
boolean searchUp = f2 > f1;
if (searchUp) {
while (f2 > f1) {
f1 = f2;
rUpper += 1;
f2 = f.value(rUpper);
sum += f2;
}
maxU = rUpper - 1;
} else {
// Ensure that u stays above zero
while (f1 > f2 && rLower > 1) {
f2 = f1;
rLower -= 1;
f1 = f.value(rLower);
sum += f1;
}
maxU = (rLower > 1) ? rLower + 1 : rLower;
}
lower = Math.max(0, maxU - 5 * sk);
upper = maxU + 5 * sk;
if (useSimpleIntegration && lower > 0) {
// remaining points in the range
for (double u = rLower - 1; u >= lower; u -= 1) {
sum += f.value(u);
}
for (double u = rUpper + 1; u <= upper; u += 1) {
sum += f.value(u);
}
} else {
// Use Legendre-Gauss integrator
try {
final double relativeAccuracy = 1e-4;
final double absoluteAccuracy = 1e-8;
final int minimalIterationCount = 3;
final int maximalIterationCount = 32;
final int integrationPoints = 16;
// Use an integrator that does not use the boundary since u=0 is undefined (divide by zero)
UnivariateIntegrator i = new IterativeLegendreGaussIntegrator(integrationPoints, relativeAccuracy, absoluteAccuracy, minimalIterationCount, maximalIterationCount);
sum = i.integrate(2000, f, lower, upper);
} catch (TooManyEvaluationsException ex) {
return mortensenApproximation(cij, eta);
}
}
// Compute the final probability, reusing f1 for the Gaussian exponent term
f1 = z / sqrt2sigma;
final double p = (FastMath.exp(-vk) / (sqrt2pi * sk)) * (FastMath.exp(-(f1 * f1)) + sum);
return (p > minimumProbability) ? p : minimumProbability;
}
}
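A minimal standalone sketch (not part of the GDSC-SMLM source; alpha, eta and cij are arbitrary illustrative values) showing why the large-argument branch works in log space: the asymptotic Bessel factor exp(2*sqrt(eta*nij)) overflows a double long before the combined log-domain expression does.
import org.apache.commons.math3.util.FastMath;
public class LogDomainSketch {
    public static void main(String[] args) {
        // Hypothetical values chosen so that eta * nij is far above the 10000 threshold.
        final double alpha = 0.1;        // photons per count (reciprocal of the gain)
        final double eta = 1000;         // expected photons
        final double cij = 10000;        // observed count, close to eta / alpha
        final double nij = alpha * cij;  // observed count converted to photons
        final double twoSqrtPi = 2 * Math.sqrt(FastMath.PI); // assumed to match the class constant
        // Naive route: the asymptotic factor alone overflows (2*sqrt(eta*nij) = 2000 >> ~709).
        System.out.println("exp(2*sqrt(eta*nij)) = " + FastMath.exp(2 * Math.sqrt(eta * nij)));
        // Log-domain route used in likelihood(): combine all terms before exponentiating.
        double transform = 0.5 * Math.log(alpha * eta / cij) - nij - eta
                + 2 * Math.sqrt(eta * nij)
                - Math.log(twoSqrtPi * Math.pow(eta * nij, 0.25));
        System.out.println("exp(transform) = " + FastMath.exp(transform)); // small but finite
    }
}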
Use of org.apache.commons.math3.util.FastMath.PI in project GDSC-SMLM by aherbert: class PCPALMClusters, method run.
/*
* (non-Javadoc)
*
* @see ij.plugin.PlugIn#run(java.lang.String)
*/
public void run(String arg) {
SMLMUsageTracker.recordPlugin(this.getClass(), arg);
if (!showDialog())
return;
PCPALMMolecules.logSpacer();
Utils.log(TITLE);
PCPALMMolecules.logSpacer();
long start = System.currentTimeMillis();
HistogramData histogramData;
if (fileInput) {
histogramData = loadHistogram(histogramFile);
} else {
histogramData = doClustering();
}
if (histogramData == null)
return;
float[][] hist = histogramData.histogram;
// Create a histogram of the cluster sizes
String title = TITLE + " Molecules/cluster";
String xTitle = "Molecules/cluster";
String yTitle = "Frequency";
// Create the data required for fitting and plotting
float[] xValues = Utils.createHistogramAxis(hist[0]);
float[] yValues = Utils.createHistogramValues(hist[1]);
// Plot the histogram
float yMax = Maths.max(yValues);
Plot2 plot = new Plot2(title, xTitle, yTitle, xValues, yValues);
if (xValues.length > 0) {
double xPadding = 0.05 * (xValues[xValues.length - 1] - xValues[0]);
plot.setLimits(xValues[0] - xPadding, xValues[xValues.length - 1] + xPadding, 0, yMax * 1.05);
}
Utils.display(title, plot);
HistogramData noiseData = loadNoiseHistogram(histogramData);
if (noiseData != null) {
if (subtractNoise(histogramData, noiseData)) {
// Update the histogram
title += " (noise subtracted)";
xValues = Utils.createHistogramAxis(hist[0]);
yValues = Utils.createHistogramValues(hist[1]);
yMax = Maths.max(yValues);
plot = new Plot2(title, xTitle, yTitle, xValues, yValues);
if (xValues.length > 0) {
double xPadding = 0.05 * (xValues[xValues.length - 1] - xValues[0]);
plot.setLimits(xValues[0] - xPadding, xValues[xValues.length - 1] + xPadding, 0, yMax * 1.05);
}
Utils.display(title, plot);
// Automatically save
if (autoSave) {
String newFilename = Utils.replaceExtension(histogramData.filename, ".noise.tsv");
if (saveHistogram(histogramData, newFilename)) {
Utils.log("Saved noise-subtracted histogram to " + newFilename);
}
}
}
}
// Fit the histogram
double[] fitParameters = fitBinomial(histogramData);
if (fitParameters != null) {
// Add the binomial to the histogram
int n = (int) fitParameters[0];
double p = fitParameters[1];
Utils.log("Optimal fit : N=%d, p=%s", n, Utils.rounded(p));
BinomialDistribution dist = new BinomialDistribution(n, p);
// A zero-truncated binomial was fitted.
// pi is the adjustment factor for the probability density.
double pi = 1 / (1 - dist.probability(0));
if (!fileInput) {
// Calculate the estimated number of clusters from the observed molecules:
// Actual = (Observed / p) / N
final double actual = (nMolecules / p) / n;
Utils.log("Estimated number of clusters : (%d / %s) / %d = %s", nMolecules, Utils.rounded(p), n, Utils.rounded(actual));
}
double[] x = new double[n + 2];
double[] y = new double[n + 2];
// Scale the values to match those on the histogram
final double normalisingFactor = count * pi;
for (int i = 0; i <= n; i++) {
x[i] = i + 0.5;
y[i] = dist.probability(i) * normalisingFactor;
}
x[n + 1] = n + 1.5;
y[n + 1] = 0;
// Redraw the plot since the limits may have changed
plot = new Plot2(title, xTitle, yTitle, xValues, yValues);
double xPadding = 0.05 * (xValues[xValues.length - 1] - xValues[0]);
plot.setLimits(xValues[0] - xPadding, xValues[xValues.length - 1] + xPadding, 0, Maths.maxDefault(yMax, y) * 1.05);
plot.setColor(Color.magenta);
plot.addPoints(x, y, Plot2.LINE);
plot.addPoints(x, y, Plot2.CIRCLE);
plot.setColor(Color.black);
Utils.display(title, plot);
}
double seconds = (System.currentTimeMillis() - start) / 1000.0;
String msg = TITLE + " complete : " + seconds + "s";
IJ.showStatus(msg);
Utils.log(msg);
return;
}
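For reference, a minimal sketch of the zero-truncated binomial overlay computed above, using hypothetical fit results (n, p) and a hypothetical histogram total rather than values from a real run:
import org.apache.commons.math3.distribution.BinomialDistribution;
public class ZeroTruncatedBinomialSketch {
    public static void main(String[] args) {
        // Hypothetical fit results and histogram total.
        final int n = 5;
        final double p = 0.3;
        final double count = 1000; // total number of clusters in the histogram
        BinomialDistribution dist = new BinomialDistribution(n, p);
        // Clusters of size zero cannot be observed, so a zero-truncated binomial was
        // fitted and the probabilities are rescaled by pi = 1 / (1 - P(X = 0)).
        double pi = 1 / (1 - dist.probability(0));
        for (int i = 0; i <= n; i++) {
            double expectedFrequency = dist.probability(i) * count * pi;
            System.out.printf("size %d: expected frequency %.1f%n", i, expectedFrequency);
        }
    }
}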
Use of org.apache.commons.math3.util.FastMath.PI in project narchy by automenta: class GreedyGQ, method update.
public double update(RealVector x_t, A a_t, double r_tp1, double gamma_tp1, double z_tp1, RealVector x_tp1, A a_tp1) {
rho_t = 0.0;
if (a_t != null && x_t != null) /*!Vectors.isNull(x_t)*/
{
target.update(x_t);
behaviour.update(x_t);
rho_t = target.pi(a_t) / behaviour.pi(a_t);
}
// assert Utils.checkValue(rho_t);
VectorPool pool = VectorPools.pool(prototype, gq.v.getDimension());
RealVector sa_bar_tp1 = pool.newVector();
// if (!Vectors.isNull(x_t) && !Vectors.isNull(x_tp1)) {
if (x_t != null && x_tp1 != null) {
target.update(x_tp1);
for (A a : actions) {
double pi = target.pi(a);
if (pi == 0)
continue;
sa_bar_tp1.combineToSelf(1, pi, toStateAction.stateAction(x_tp1, a));
}
}
RealVector phi_stat = x_t != null ? toStateAction.stateAction(x_t, a_t) : null;
double delta_t = gq.update(phi_stat, rho_t, r_tp1, sa_bar_tp1, z_tp1);
pool.releaseAll();
return delta_t;
}
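The importance-sampling ratio rho_t = pi_target(a_t) / pi_behaviour(a_t) computed at the top of update is the only place the two policies interact. A minimal sketch with hypothetical action-probability arrays (nothing here uses the narchy API):
public class ImportanceRatioSketch {
    public static void main(String[] args) {
        // Hypothetical probabilities of three actions under each policy.
        double[] targetPi = { 1.0, 0.0, 0.0 };    // greedy target: all mass on action 0
        double[] behaviourPi = { 0.8, 0.1, 0.1 }; // epsilon-greedy behaviour policy
        int aTaken = 0; // action actually taken by the behaviour policy
        double rho = targetPi[aTaken] / behaviourPi[aTaken];
        // rho > 1 up-weights transitions the target policy favours;
        // rho = 0 discards transitions the target policy would never take.
        System.out.println("rho_t = " + rho);
    }
}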
Use of org.apache.commons.math3.util.FastMath.PI in project knime-core by knime: class Learner, method irlsRls (variant solving with QRDecomposition).
/**
* Do an IRLS (iteratively reweighted least squares) step. The result is stored in beta.
*
* @param data the training data
* @param beta the parameter vector
* @param rC the regressor count
* @param tcC the target category count
* @param exec the execution monitor
* @throws CanceledExecutionException when the method is cancelled
*/
private void irlsRls(final RegressionTrainingData data, final RealMatrix beta, final int rC, final int tcC, final ExecutionMonitor exec) throws CanceledExecutionException {
Iterator<RegressionTrainingRow> iter = data.iterator();
long rowCount = 0;
int dim = (rC + 1) * (tcC - 1);
RealMatrix xTwx = new Array2DRowRealMatrix(dim, dim);
RealMatrix xTyu = new Array2DRowRealMatrix(dim, 1);
RealMatrix x = new Array2DRowRealMatrix(1, rC + 1);
RealMatrix eBetaTx = new Array2DRowRealMatrix(1, tcC - 1);
RealMatrix pi = new Array2DRowRealMatrix(1, tcC - 1);
final long totalRowCount = data.getRowCount();
while (iter.hasNext()) {
rowCount++;
RegressionTrainingRow row = iter.next();
exec.checkCanceled();
exec.setProgress(rowCount / (double) totalRowCount, "Row " + rowCount + "/" + totalRowCount);
x.setEntry(0, 0, 1);
x.setSubMatrix(row.getParameter().getData(), 0, 1);
for (int k = 0; k < tcC - 1; k++) {
RealMatrix betaITx = x.multiply(beta.getSubMatrix(0, 0, k * (rC + 1), (k + 1) * (rC + 1) - 1).transpose());
eBetaTx.setEntry(0, k, Math.exp(betaITx.getEntry(0, 0)));
}
double sumEBetaTx = 0;
for (int k = 0; k < tcC - 1; k++) {
sumEBetaTx += eBetaTx.getEntry(0, k);
}
for (int k = 0; k < tcC - 1; k++) {
double pik = eBetaTx.getEntry(0, k) / (1 + sumEBetaTx);
pi.setEntry(0, k, pik);
}
// fill the diagonal blocks of matrix xTwx (k = k')
for (int k = 0; k < tcC - 1; k++) {
for (int i = 0; i < rC + 1; i++) {
for (int ii = i; ii < rC + 1; ii++) {
int o = k * (rC + 1);
double v = xTwx.getEntry(o + i, o + ii);
double w = pi.getEntry(0, k) * (1 - pi.getEntry(0, k));
v += x.getEntry(0, i) * w * x.getEntry(0, ii);
xTwx.setEntry(o + i, o + ii, v);
xTwx.setEntry(o + ii, o + i, v);
}
}
}
// fill the rest of xTwx (k != k')
for (int k = 0; k < tcC - 1; k++) {
for (int kk = k + 1; kk < tcC - 1; kk++) {
for (int i = 0; i < rC + 1; i++) {
for (int ii = i; ii < rC + 1; ii++) {
int o1 = k * (rC + 1);
int o2 = kk * (rC + 1);
double v = xTwx.getEntry(o1 + i, o2 + ii);
double w = -pi.getEntry(0, k) * pi.getEntry(0, kk);
v += x.getEntry(0, i) * w * x.getEntry(0, ii);
xTwx.setEntry(o1 + i, o2 + ii, v);
xTwx.setEntry(o1 + ii, o2 + i, v);
xTwx.setEntry(o2 + ii, o1 + i, v);
xTwx.setEntry(o2 + i, o1 + ii, v);
}
}
}
}
int g = (int) row.getTarget();
// fill matrix xTyu
for (int k = 0; k < tcC - 1; k++) {
for (int i = 0; i < rC + 1; i++) {
int o = k * (rC + 1);
double v = xTyu.getEntry(o + i, 0);
double y = k == g ? 1 : 0;
v += (y - pi.getEntry(0, k)) * x.getEntry(0, i);
xTyu.setEntry(o + i, 0, v);
}
}
}
if (m_penaltyTerm > 0.0) {
RealMatrix stdError = getStdErrorMatrix(xTwx);
// do not penalize the constant terms
for (int i = 0; i < tcC - 1; i++) {
stdError.setEntry(i * (rC + 1), i * (rC + 1), 0);
}
xTwx = xTwx.add(stdError.scalarMultiply(-0.00001));
}
exec.checkCanceled();
b = xTwx.multiply(beta.transpose()).add(xTyu);
A = xTwx;
if (rowCount < A.getColumnDimension()) {
throw new IllegalStateException("The dataset must have at least " + A.getColumnDimension() + " rows, but it has only " + rowCount + " rows. It is recommended to use a " + "larger dataset in order to increase accuracy.");
}
DecompositionSolver solver = new QRDecomposition(A).getSolver();
boolean isNonSingular = solver.isNonSingular();
if (isNonSingular) {
RealMatrix betaNew = solver.solve(b);
beta.setSubMatrix(betaNew.transpose().getData(), 0, 0);
} else {
throw new RuntimeException(FAILING_MSG);
}
}
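A minimal sketch of the final solve step with commons-math3, using a small hypothetical system in place of A = X^T W X and b (the real matrices are built in the loop above):
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
public class QrSolveSketch {
    public static void main(String[] args) {
        // Hypothetical, well-conditioned 2x2 system.
        RealMatrix A = new Array2DRowRealMatrix(new double[][] { { 4, 1 }, { 1, 3 } });
        RealMatrix b = new Array2DRowRealMatrix(new double[][] { { 1 }, { 2 } });
        DecompositionSolver solver = new QRDecomposition(A).getSolver();
        if (solver.isNonSingular()) {
            RealMatrix betaNew = solver.solve(b); // next parameter estimate
            System.out.println(betaNew);
        } else {
            // Mirrors the learner: a singular system means the IRLS step cannot be taken.
            throw new IllegalStateException("X^T W X is singular");
        }
    }
}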
Use of org.apache.commons.math3.util.FastMath.PI in project knime-core by knime: class Learner, method irlsRls (variant solving with SingularValueDecomposition).
/**
* Do an IRLS (iteratively reweighted least squares) step. The result is stored in beta.
*
* @param data the training data
* @param beta the parameter vector
* @param rC the regressor count
* @param tcC the target category count
* @param exec the execution monitor
* @throws CanceledExecutionException when the method is cancelled
*/
private void irlsRls(final RegressionTrainingData data, final RealMatrix beta, final int rC, final int tcC, final ExecutionMonitor exec) throws CanceledExecutionException {
Iterator<RegressionTrainingRow> iter = data.iterator();
long rowCount = 0;
int dim = (rC + 1) * (tcC - 1);
RealMatrix xTwx = new Array2DRowRealMatrix(dim, dim);
RealMatrix xTyu = new Array2DRowRealMatrix(dim, 1);
RealMatrix x = new Array2DRowRealMatrix(1, rC + 1);
RealMatrix eBetaTx = new Array2DRowRealMatrix(1, tcC - 1);
RealMatrix pi = new Array2DRowRealMatrix(1, tcC - 1);
final long totalRowCount = data.getRowCount();
while (iter.hasNext()) {
rowCount++;
RegressionTrainingRow row = iter.next();
exec.checkCanceled();
exec.setProgress(rowCount / (double) totalRowCount, "Row " + rowCount + "/" + totalRowCount);
x.setEntry(0, 0, 1);
x.setSubMatrix(row.getParameter().getData(), 0, 1);
for (int k = 0; k < tcC - 1; k++) {
RealMatrix betaITx = x.multiply(beta.getSubMatrix(0, 0, k * (rC + 1), (k + 1) * (rC + 1) - 1).transpose());
eBetaTx.setEntry(0, k, Math.exp(betaITx.getEntry(0, 0)));
}
double sumEBetaTx = 0;
for (int k = 0; k < tcC - 1; k++) {
sumEBetaTx += eBetaTx.getEntry(0, k);
}
for (int k = 0; k < tcC - 1; k++) {
double pik = eBetaTx.getEntry(0, k) / (1 + sumEBetaTx);
pi.setEntry(0, k, pik);
}
// fill the diagonal blocks of matrix xTwx (k = k')
for (int k = 0; k < tcC - 1; k++) {
for (int i = 0; i < rC + 1; i++) {
for (int ii = i; ii < rC + 1; ii++) {
int o = k * (rC + 1);
double v = xTwx.getEntry(o + i, o + ii);
double w = pi.getEntry(0, k) * (1 - pi.getEntry(0, k));
v += x.getEntry(0, i) * w * x.getEntry(0, ii);
xTwx.setEntry(o + i, o + ii, v);
xTwx.setEntry(o + ii, o + i, v);
}
}
}
// fill the rest of xTwx (k != k')
for (int k = 0; k < tcC - 1; k++) {
for (int kk = k + 1; kk < tcC - 1; kk++) {
for (int i = 0; i < rC + 1; i++) {
for (int ii = i; ii < rC + 1; ii++) {
int o1 = k * (rC + 1);
int o2 = kk * (rC + 1);
double v = xTwx.getEntry(o1 + i, o2 + ii);
double w = -pi.getEntry(0, k) * pi.getEntry(0, kk);
v += x.getEntry(0, i) * w * x.getEntry(0, ii);
xTwx.setEntry(o1 + i, o2 + ii, v);
xTwx.setEntry(o1 + ii, o2 + i, v);
xTwx.setEntry(o2 + ii, o1 + i, v);
xTwx.setEntry(o2 + i, o1 + ii, v);
}
}
}
}
int g = (int) row.getTarget();
// fill matrix xTyu
for (int k = 0; k < tcC - 1; k++) {
for (int i = 0; i < rC + 1; i++) {
int o = k * (rC + 1);
double v = xTyu.getEntry(o + i, 0);
double y = k == g ? 1 : 0;
v += (y - pi.getEntry(0, k)) * x.getEntry(0, i);
xTyu.setEntry(o + i, 0, v);
}
}
}
if (m_penaltyTerm > 0.0) {
RealMatrix stdError = getStdErrorMatrix(xTwx);
// do not penalize the constant terms
for (int i = 0; i < tcC - 1; i++) {
stdError.setEntry(i * (rC + 1), i * (rC + 1), 0);
}
xTwx = xTwx.add(stdError.scalarMultiply(-0.00001));
}
exec.checkCanceled();
b = xTwx.multiply(beta.transpose()).add(xTyu);
A = xTwx;
if (rowCount < A.getColumnDimension()) {
throw new IllegalStateException("The dataset must have at least " + A.getColumnDimension() + " rows, but it has only " + rowCount + " rows. It is recommended to use a " + "larger dataset in order to increase accuracy.");
}
DecompositionSolver solver = new SingularValueDecomposition(A).getSolver();
// boolean isNonSingular = solver.isNonSingular();
// if (isNonSingular) {
RealMatrix betaNew = solver.solve(b);
beta.setSubMatrix(betaNew.transpose().getData(), 0, 0);
// } else {
// throw new RuntimeException(FAILING_MSG);
// }
}
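This variant differs from the previous one only in the solver: SingularValueDecomposition replaces QRDecomposition, so the non-singularity check can be dropped. A minimal sketch with a deliberately rank-deficient hypothetical system:
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
public class SvdSolveSketch {
    public static void main(String[] args) {
        // Hypothetical singular system: the second row is half the first,
        // so a QR-based solver would report it as singular.
        RealMatrix A = new Array2DRowRealMatrix(new double[][] { { 2, 4 }, { 1, 2 } });
        RealMatrix b = new Array2DRowRealMatrix(new double[][] { { 2 }, { 1 } });
        // The SVD solver does not throw for rank-deficient input; it returns a
        // least-squares (pseudo-inverse) solution instead.
        RealMatrix betaNew = new SingularValueDecomposition(A).getSolver().solve(b);
        System.out.println(betaNew);
    }
}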