Use of uk.ac.sussex.gdsc.smlm.function.cspline.CubicSplineCalculator in project GDSC-SMLM by aherbert.
In class CubicSplineManager, method createCubicSpline.
/**
 * Creates the cubic spline.
 *
 * <p>The image is reduced by a factor of 3 in each dimension. Each 4x4x4 block of pixels whose
 * origin lies on a multiple of 3 is converted to one tricubic spline node. A dimension of length
 * {@code n*3 + 1} therefore yields {@code n} nodes. The nodes of each z plane are computed in a
 * separate task. The splines are then scaled so that the z plane with the largest sum of
 * {@code value000()} over its nodes sums to 1.
 *
 * @param imagePsf the image PSF details
 * @param image the image
 * @param singlePrecision Set to true to use single precision (float values) to store the cubic
 *        spline coefficients
 * @return the cubic spline PSF
 * @throws IllegalStateException if the maximum plane sum used for normalisation is zero (e.g. the
 *         image is too small to contain a spline node, or has no signal)
 */
public static CubicSplinePsf createCubicSpline(ImagePSFOrBuilder imagePsf, ImageStack image,
    final boolean singlePrecision) {
  final int maxx = image.getWidth();
  final int maxy = image.getHeight();
  final int maxz = image.getSize();
  final float[][] psf = new float[maxz][];
  for (int z = 0; z < maxz; z++) {
    // ImageJ stack slices are 1-based
    psf[z] = ImageJImageConverter.getData(image.getPixels(z + 1), null);
  }

  // We reduce by a factor of 3: there is one spline node along each dimension for every
  // 3 pixels, i.e. dimension length = n*3 + 1 with n the number of nodes.
  final int maxi = (maxx - 1) / 3;
  final int maxj = (maxy - 1) / 3;
  final int maxk = (maxz - 1) / 3;
  final int size = maxi * maxj;
  final CustomTricubicFunction[][] splines = new CustomTricubicFunction[maxk][size];

  final int threadCount = Prefs.getThreads();
  final Ticker ticker = ImageJUtils.createTicker((long) maxi * maxj * maxk, threadCount);
  final ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
  final LocalList<Future<?>> futures = new LocalList<>(maxk);
  for (int k = 0; k < maxk; k++) {
    final int kk = k;
    futures.add(threadPool.submit(() -> {
      // Each task only writes to its own z plane (splines[kk])
      final CubicSplineCalculator calc = new CubicSplineCalculator();
      final double[] value = new double[64];
      final int zz = 3 * kk;
      for (int j = 0, index = 0; j < maxj; j++) {
        // 4x4 block origin in the XY data
        int index0 = 3 * j * maxx;
        for (int i = 0; i < maxi; i++, index++) {
          ticker.tick();
          // Extract the 4x4x4 block of pixels for this node
          int count = 0;
          for (int z = 0; z < 4; z++) {
            final float[] data = psf[zz + z];
            for (int y = 0; y < 4; y++) {
              for (int x = 0, ii = index0 + y * maxx; x < 4; x++) {
                value[count++] = data[ii++];
              }
            }
          }
          CustomTricubicFunction spline = CustomTricubicFunctionUtils.create(calc.compute(value));
          if (singlePrecision) {
            spline = spline.toSinglePrecision();
          }
          splines[kk][index] = spline;
          index0 += 3;
        }
      }
    }));
  }
  threadPool.shutdown();
  ConcurrencyUtils.waitForCompletionUnchecked(futures);
  // Only report completion once every task has finished ticking
  ticker.stop();

  // Normalise: scale so the z plane with the largest sum of node values sums to 1
  double maxSum = 0;
  for (int k = 0; k < maxk; k++) {
    double sum = 0;
    for (int i = 0; i < size; i++) {
      sum += splines[k][i].value000();
    }
    if (maxSum < sum) {
      maxSum = sum;
    }
  }
  if (maxSum == 0) {
    throw new IllegalStateException("The cubic spline has no maximum signal");
  }
  final double scale = 1.0 / maxSum;
  for (int k = 0; k < maxk; k++) {
    for (int i = 0; i < size; i++) {
      splines[k][i] = splines[k][i].scale(scale);
    }
  }

  // Create on an integer scale
  final CubicSplineData f = new CubicSplineData(maxi, maxj, splines);

  // Create a new info with the PSF details
  final ImagePSF.Builder b = ImagePSF.newBuilder();
  b.setImageCount(imagePsf.getImageCount());
  // Reducing the image has the effect of enlarging the pixel size
  b.setPixelSize(imagePsf.getPixelSize() * 3.0);
  b.setPixelDepth(imagePsf.getPixelDepth() * 3.0);

  // The centre has to be moved as we reduced the image size by 3.
  // In the ImagePSF the XY centre puts 0.5 at the centre of the pixel.
  // The spline puts 0,0 at the centre of each pixel for convenience.
  // A zero centre in the input is treated as unset and the centre of the spline is used.
  double cx = maxi / 2.0;
  if (imagePsf.getXCentre() != 0) {
    cx = (imagePsf.getXCentre() - 0.5) / 3;
  }
  double cy = maxj / 2.0;
  if (imagePsf.getYCentre() != 0) {
    cy = (imagePsf.getYCentre() - 0.5) / 3;
  }
  double cz = maxk / 2.0;
  if (imagePsf.getZCentre() != 0) {
    cz = imagePsf.getZCentre() / 3;
  } else if (imagePsf.getCentreImage() != 0) {
    // Centre image is 1-based
    cz = (imagePsf.getCentreImage() - 1) / 3.0;
  }
  b.setXCentre(cx);
  b.setYCentre(cy);
  b.setZCentre(cz);
  return new CubicSplinePsf(b.build(), f);
}
Use of uk.ac.sussex.gdsc.smlm.function.cspline.CubicSplineCalculator in project GDSC-SMLM by aherbert.
In class Image3DAligner, method align.
/**
 * Align the image with the reference with sub-pixel accuracy. Compute the translation required to
 * move the target image onto the reference image for maximum correlation.
 *
 * <p>The cross-correlation is computed using the Hartley transform. It is then normalised in the
 * spatial domain to a Pearson correlation, using the rolling sums of the reference and target
 * over the overlapping region. The maximum normalised correlation is located to pixel accuracy.
 * Positions where the overlap is below the minimum are skipped, as are positions with a
 * normalised value above 1 (plus a small tolerance). If no position qualifies, the maximum of the
 * unnormalised correlation is used instead.
 *
 * <p>If {@code refinements > 0}, a tricubic spline is built from a 4x4x4 region around the
 * maximum. The spline is then searched for a higher, sub-pixel optimum.
 *
 * <p>Side effects: the shared {@code buffer} is replaced with the correlation data, and positions
 * within the crop are overwritten with normalised values. {@code cropDimensions} is set.
 *
 * @param target the target
 * @param refinements the maximum number of refinements for sub-pixel accuracy
 * @param error the error for sub-pixel accuracy (i.e. stop when improvements are less than this
 *        error)
 * @return [x,y,z,value]
 * @throws IllegalArgumentException If any dimension is less than 2, or if larger than the
 *         initialised reference
 */
private double[] align(DhtData target, int refinements, double error) {
  // Multiply by the reference. This allows the reference to be shared across threads.
  final DoubleDht3D correlation = target.dht.conjugateMultiply(reference.dht, buffer);
  // Store for reuse
  buffer = correlation.getData();
  correlation.inverseTransform();
  correlation.swapOctants();

  // Normalise:
  // ( Σ xiyi - n x̄ ȳ ) / ( (Σ xi^2 - n x̄^2) (Σ yi^2 - n ȳ^2) )^0.5
  //
  // (sumXy - sumX*sumY/n) / sqrt( (sumXx - sumX^2 / n) * (sumYy - sumY^2 / n) )

  // Only do this over the range where at least half the original images overlap,
  // i.e. the insert point of one will be the middle of the other when shifted.
  int ix = Math.min(reference.ix, target.ix);
  int iy = Math.min(reference.iy, target.iy);
  int iz = Math.min(reference.iz, target.iz);
  int ixw = Math.max(reference.ix + reference.width, target.ix + target.width);
  int iyh = Math.max(reference.iy + reference.height, target.iy + target.height);
  int izd = Math.max(reference.iz + reference.depth, target.iz + target.depth);

  if (minimumDimensionOverlap > 0) {
    final double f = (1 - minimumDimensionOverlap) / 2;
    final int ux = (int) (Math.round(Math.min(reference.width, target.width) * f));
    final int uy = (int) (Math.round(Math.min(reference.height, target.height) * f));
    final int uz = (int) (Math.round(Math.min(reference.depth, target.depth) * f));
    ix += ux;
    ixw -= ux;
    iy += uy;
    iyh -= uy;
    iz += uz;
    izd -= uz;
  }

  cropDimensions = new int[] {ix, iy, iz, ixw - ix, iyh - iy, izd - iz};

  // The maximum correlation unnormalised. Since this is unnormalised
  // it will be biased towards the centre of the image. This is used
  // to restrict the bounds for finding the maximum of the normalised correlation
  // which should be close to this.
  int maxi = correlation.findMaxIndex(ix, iy, iz, cropDimensions[3], cropDimensions[4],
      cropDimensions[5]);

  // Check in the spatial domain
  checkCorrelation(target, correlation, maxi);

  // Compute sum from rolling sum using:
  // sum(x,y,z,w,h,d) =
  // + s(x+w-1,y+h-1,z+d-1)
  // - s(x-1,y+h-1,z+d-1)
  // - s(x+w-1,y-1,z+d-1)
  // + s(x-1,y-1,z+d-1)
  // /* Image above must be subtracted so reverse sign*/
  // - s(x+w-1,y+h-1,z-1)
  // + s(x-1,y+h-1,z-1)
  // + s(x+w-1,y-1,z-1)
  // - s(x-1,y-1,z-1)
  // Note:
  // s(i,j,k) = 0 when either i,j,k < 0
  // i = imax when i>imax
  // j = jmax when j>jmax
  // k = kmax when k>kmax

  // Note: The correlation is for the movement of the reference over the target
  final int nc_2 = nc / 2;
  final int nr_2 = nr / 2;
  final int ns_2 = ns / 2;
  final int[] centre = new int[] {nc_2, nr_2, ns_2};

  // Compute the shift from the centre
  final int dx = nc_2 - ix;
  final int dy = nr_2 - iy;
  final int dz = ns_2 - iz;

  // For the reference (moved -dx,-dy,-dz over the target)
  int rx = -dx;
  int ry = -dy;
  int rz = -dz;

  // For the target (moved dx,dy,dz over the reference)
  int tx = dx;
  int ty = dy;
  int tz = dz;

  // Precompute the x-1,x+w-1,y-1,y+h-1
  final int nx = cropDimensions[3];
  final int[] rx1 = new int[nx];
  final int[] rxw1 = new int[nx];
  final int[] tx1 = new int[nx];
  final int[] txw1 = new int[nx];
  final int[] width = new int[nx];
  for (int c = ix, i = 0; c < ixw; c++, i++) {
    rx1[i] = Math.max(-1, rx - 1);
    rxw1[i] = Math.min(nc, rx + nc) - 1;
    rx++;
    tx1[i] = Math.max(-1, tx - 1);
    txw1[i] = Math.min(nc, tx + nc) - 1;
    tx--;
    width[i] = rxw1[i] - rx1[i];
  }
  final int ny = cropDimensions[4];
  final int[] ry1 = new int[ny];
  final int[] ryh1 = new int[ny];
  final int[] ty1 = new int[ny];
  final int[] tyh1 = new int[ny];
  final int[] h = new int[ny];
  for (int r = iy, j = 0; r < iyh; r++, j++) {
    ry1[j] = Math.max(-1, ry - 1);
    ryh1[j] = Math.min(nr, ry + nr) - 1;
    ry++;
    ty1[j] = Math.max(-1, ty - 1);
    tyh1[j] = Math.min(nr, ty + nr) - 1;
    ty--;
    h[j] = ryh1[j] - ry1[j];
  }

  final double[] rs = reference.sum;
  final double[] rss = reference.sumSq;
  final double[] ts = target.sum;
  final double[] tss = target.sumSq;
  final double[] rsum = new double[2];
  final double[] tsum = new double[2];

  final int size = Math.min(reference.size, target.size);
  final int minimumN = (int) (Math.round(size * minimumOverlap));
  // Index of the maximum normalised correlation; -1 until a valid position is found.
  // Starting below any possible correlation (rather than at 0) ensures a valid maximum is
  // found even when every correlation is <= 0.
  int maxj = -1;
  double max = Double.NEGATIVE_INFINITY;

  for (int s = iz; s < izd; s++) {
    // Compute the z-1,z+d-1
    final int rz_1 = Math.max(-1, rz - 1);
    final int rz_d_1 = Math.min(ns, rz + ns) - 1;
    rz++;
    final int tz_1 = Math.max(-1, tz - 1);
    final int tz_d_1 = Math.min(ns, tz + ns) - 1;
    tz--;
    final int d = rz_d_1 - rz_1;

    for (int r = iy, j = 0; r < iyh; r++, j++) {
      final int base = s * nrByNc + r * nc;
      final int hd = h[j] * d;
      for (int c = ix, i = 0; c < ixw; c++, i++) {
        final double sumXy = buffer[base + c];

        compute(rx1[i], ry1[j], rz_1, rxw1[i], ryh1[j], rz_d_1, width[i], h[j], d, rs, rss,
            rsum);
        compute(tx1[i], ty1[j], tz_1, txw1[i], tyh1[j], tz_d_1, width[i], h[j], d, ts, tss,
            tsum);

        // Compute the correlation
        // (sumXy - sumX*sumY/n) / sqrt( (sumXx - sumX^2 / n) * (sumYy - sumY^2 / n) )

        final int n = width[i] * hd;
        final double numerator = sumXy - (rsum[X] * tsum[Y] / n);
        final double denominator1 = rsum[XX] - (rsum[X] * rsum[X] / n);
        final double denominator2 = tsum[YY] - (tsum[Y] * tsum[Y] / n);

        double corr;
        if (denominator1 == 0 || denominator2 == 0) {
          // If there is data and all the variances are the same then correlation is perfect
          if (rsum[XX] == tsum[YY] && rsum[XX] == sumXy && rsum[XX] > 0) {
            corr = 1;
          } else {
            corr = 0;
          }
        } else {
          // Leave as raw for debugging, i.e. do not clip to range [-1:1]
          corr = numerator / Math.sqrt(denominator1 * denominator2);
        }

        buffer[base + c] = corr;

        if (n < minimumN) {
          continue;
        }

        // Check normalisation with some margin for error
        if (corr > 1.0001) {
          // It is likely to occur at the bounds.
          continue;
        }

        if (corr > max) {
          max = corr;
          maxj = base + c;
        } else if (corr == max && maxj >= 0) {
          // Equal correlation: prefer the shift closest to the centre
          final int[] xyz1 = correlation.getXyz(maxj);
          final int[] xyz2 = correlation.getXyz(base + c);
          int d1 = 0;
          int d2 = 0;
          for (int k = 0; k < 3; k++) {
            d1 += MathUtils.pow2(xyz1[k] - centre[k]);
            d2 += MathUtils.pow2(xyz2[k] - centre[k]);
          }
          if (d2 < d1) {
            maxj = base + c;
          }
        }
      }
    }
  }

  // The maximum correlation with normalisation.
  // If no position passed the overlap/normalisation checks then fall back to the maximum of
  // the unnormalised correlation. The buffer holds its normalised value as it lies in the
  // same crop region.
  if (maxj >= 0) {
    maxi = maxj;
  }
  final int[] xyz = correlation.getXyz(maxi);

  // Report the shift required to move from the centre of the target image to the reference
  final double[] result =
      new double[] {nc_2 - xyz[0], nr_2 - xyz[1], ns_2 - xyz[2], buffer[maxi]};

  if (refinements > 0) {
    // Create a cubic spline using a small region of pixels around the maximum
    if (calc == null) {
      calc = new CubicSplineCalculator();
    }
    // Avoid out-of-bounds errors. Only use the range that was normalised
    final int x = MathUtils.clip(ix, ixw - 4, xyz[0] - 1);
    final int y = MathUtils.clip(iy, iyh - 4, xyz[1] - 1);
    final int z = MathUtils.clip(iz, izd - 4, xyz[2] - 1);
    final DoubleImage3D crop = correlation.crop(x, y, z, 4, 4, 4, region);
    region = crop.getData();
    final CustomTricubicFunction f = CustomTricubicFunctionUtils.create(calc.compute(region));

    // Find the maximum starting at the current origin
    final int ox = xyz[0] - x;
    final int oy = xyz[1] - y;
    final int oz = xyz[2] - z;

    // Scale to the cubic spline dimensions of 0-1
    final double[] origin = new double[] {ox / 3.0, oy / 3.0, oz / 3.0};

    // Simple condensing search
    if (searchMode == SearchMode.BINARY) {
      // Can this use the current origin as a start point?
      // Currently we evaluate 8-cube vertices. A better search
      // would evaluate 27 points around the optimum, pick the best then condense
      // the range.
      final double[] optimum = f.search(true, refinements, relativeThreshold, -1);
      final double value = optimum[3];
      if (value > result[3]) {
        result[3] = value;
        // Convert the maximum back with scaling
        for (int i = 0; i < 3; i++) {
          result[i] -= (optimum[i] - origin[i]) * 3.0;
        }
        return result;
      }
    } else {
      // Gradient search
      try {
        final SplineFunction sf = new SplineFunction(f, origin);

        final BoundedNonLinearConjugateGradientOptimizer optimiser =
            new BoundedNonLinearConjugateGradientOptimizer(
                BoundedNonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
                // set the number of refinements
                new SimpleValueChecker(relativeThreshold, -1, refinements));
        final PointValuePair opt = optimiser.optimize(maxEvaluations, bounds, GoalType.MINIMIZE,
            new InitialGuess(origin),
            // Scale the error for the position check
            new PositionChecker(-1, error / 3.0), new ObjectiveFunction(sf::value),
            new ObjectiveFunctionGradient(point -> {
              // This must be new each time
              final double[] partialDerivative1 = new double[3];
              sf.value(point, partialDerivative1);
              return partialDerivative1;
            }));

        // Check it is higher. Invert since we did a minimisation.
        final double value = -opt.getValue();
        if (value > result[3]) {
          result[3] = value;
          // Convert the maximum back with scaling
          final double[] optimum = opt.getPointRef();
          for (int i = 0; i < 3; i++) {
            result[i] -= (optimum[i] - origin[i]) * 3.0;
          }
          return result;
        }
      } catch (final Exception ignored) {
        // Best-effort refinement: if the optimiser fails (e.g. too many evaluations or no
        // convergence) the pixel-accurate result computed above is returned unchanged.
      }
    }
  }

  return result;
}
Aggregations