Search in sources:

Example 6 with PositionChecker

Use of uk.ac.sussex.gdsc.smlm.math3.optim.PositionChecker in project GDSC-SMLM by aherbert.

In the class Image3DAligner, method align.

/**
 * Align the image with the reference with sub-pixel accuracy. Compute the translation required to
 * move the target image onto the reference image for maximum correlation.
 *
 * <p>The cross-correlation is computed by multiplying the transformed target and reference and
 * inverting the transform. It is then normalised in the spatial domain over the region where the
 * images overlap sufficiently (controlled by {@code minimumDimensionOverlap} and
 * {@code minimumOverlap}). The peak of the normalised correlation is optionally refined to
 * sub-pixel accuracy using a tricubic spline built from the 4x4x4 neighbourhood of the peak.
 *
 * <p>If no position within the normalised region passes the validity checks, the peak of the
 * unnormalised correlation is used instead.
 *
 * @param target the target
 * @param refinements the maximum number of refinements for sub-pixel accuracy
 * @param error the error for sub-pixel accuracy (i.e. stop when improvements are less than this
 *        error). Only used by the gradient search; the binary search is limited by
 *        {@code refinements} and {@code relativeThreshold}.
 * @return [x,y,z,value]: the shift to move the target onto the reference and the correlation
 *         value at that shift
 * @throws IllegalArgumentException If any dimension is less than 2, or if larger than the
 *         initialised reference
 */
private double[] align(DhtData target, int refinements, double error) {
    // Multiply by the reference. This allows the reference to be shared across threads.
    final DoubleDht3D correlation = target.dht.conjugateMultiply(reference.dht, buffer);
    // Store for reuse. Note: buffer is now the backing data of the correlation and is
    // overwritten in place with the normalised correlation below.
    buffer = correlation.getData();
    correlation.inverseTransform();
    correlation.swapOctants();

    // Normalise using the Pearson correlation:
    // (sumXy - sumX*sumY/n) / sqrt( (sumXx - sumX^2 / n) * (sumYy - sumY^2 / n) )
    // Only do this over the range where at least half the original images overlap,
    // i.e. the insert point of one will be the middle of the other when shifted.
    int ix = Math.min(reference.ix, target.ix);
    int iy = Math.min(reference.iy, target.iy);
    int iz = Math.min(reference.iz, target.iz);
    int ixw = Math.max(reference.ix + reference.width, target.ix + target.width);
    int iyh = Math.max(reference.iy + reference.height, target.iy + target.height);
    int izd = Math.max(reference.iz + reference.depth, target.iz + target.depth);
    if (minimumDimensionOverlap > 0) {
        // Shrink the search range so each dimension overlaps by at least the configured fraction
        final double f = (1 - minimumDimensionOverlap) / 2;
        final int ux = (int) (Math.round(Math.min(reference.width, target.width) * f));
        final int uy = (int) (Math.round(Math.min(reference.height, target.height) * f));
        final int uz = (int) (Math.round(Math.min(reference.depth, target.depth) * f));
        ix += ux;
        ixw -= ux;
        iy += uy;
        iyh -= uy;
        iz += uz;
        izd -= uz;
    }
    cropDimensions = new int[] { ix, iy, iz, ixw - ix, iyh - iy, izd - iz };

    // The maximum correlation unnormalised. Since this is unnormalised
    // it will be biased towards the centre of the image. It is used as a sanity
    // check and as the fallback peak if no normalised value is valid.
    int maxi = correlation.findMaxIndex(ix, iy, iz, cropDimensions[3], cropDimensions[4], cropDimensions[5]);
    // Check in the spatial domain
    checkCorrelation(target, correlation, maxi);

    // Compute sum from rolling sum using:
    // sum(x,y,z,w,h,d) =
    // + s(x+w-1,y+h-1,z+d-1)
    // - s(x-1,y+h-1,z+d-1)
    // - s(x+w-1,y-1,z+d-1)
    // + s(x-1,y-1,z+d-1)
    // /* Image above must be subtracted so reverse sign*/
    // - s(x+w-1,y+h-1,z-1)
    // + s(x-1,y+h-1,z-1)
    // + s(x+w-1,y-1,z-1)
    // - s(x-1,y-1,z-1)
    // Note:
    // s(i,j,k) = 0 when either i,j,k < 0
    // i = imax when i>imax
    // j = jmax when j>jmax
    // k = kmax when k>kmax

    // Note: The correlation is for the movement of the reference over the target
    final int nc_2 = nc / 2;
    final int nr_2 = nr / 2;
    final int ns_2 = ns / 2;
    final int[] centre = new int[] { nc_2, nr_2, ns_2 };

    // Compute the shift from the centre
    final int dx = nc_2 - ix;
    final int dy = nr_2 - iy;
    final int dz = ns_2 - iz;
    // For the reference (moved -dx,-dy,-dz over the target)
    int rx = -dx;
    int ry = -dy;
    int rz = -dz;
    // For the target (moved dx,dy,dz over the reference)
    int tx = dx;
    int ty = dy;
    int tz = dz;

    // Precompute the x-1,x+w-1 bounds (clipped to the image) for each column of the search range
    final int nx = cropDimensions[3];
    final int[] rx1 = new int[nx];
    final int[] rxw1 = new int[nx];
    final int[] tx1 = new int[nx];
    final int[] txw1 = new int[nx];
    final int[] width = new int[nx];
    for (int c = ix, i = 0; c < ixw; c++, i++) {
        rx1[i] = Math.max(-1, rx - 1);
        rxw1[i] = Math.min(nc, rx + nc) - 1;
        rx++;
        tx1[i] = Math.max(-1, tx - 1);
        txw1[i] = Math.min(nc, tx + nc) - 1;
        tx--;
        width[i] = rxw1[i] - rx1[i];
    }
    // Precompute the y-1,y+h-1 bounds (clipped to the image) for each row of the search range
    final int ny = cropDimensions[4];
    final int[] ry1 = new int[ny];
    final int[] ryh1 = new int[ny];
    final int[] ty1 = new int[ny];
    final int[] tyh1 = new int[ny];
    final int[] h = new int[ny];
    for (int r = iy, j = 0; r < iyh; r++, j++) {
        ry1[j] = Math.max(-1, ry - 1);
        ryh1[j] = Math.min(nr, ry + nr) - 1;
        ry++;
        ty1[j] = Math.max(-1, ty - 1);
        tyh1[j] = Math.min(nr, ty + nr) - 1;
        ty--;
        h[j] = ryh1[j] - ry1[j];
    }

    final double[] rs = reference.sum;
    final double[] rss = reference.sumSq;
    final double[] ts = target.sum;
    final double[] tss = target.sumSq;
    final double[] rsum = new double[2];
    final double[] tsum = new double[2];
    final int size = Math.min(reference.size, target.size);
    final int minimumN = (int) (Math.round(size * minimumOverlap));

    // Index and value of the best valid normalised correlation; maxj < 0 means none found yet
    int maxj = -1;
    double max = Double.NEGATIVE_INFINITY;
    for (int s = iz; s < izd; s++) {
        // Compute the z-1,z+d-1
        final int rz_1 = Math.max(-1, rz - 1);
        final int rz_d_1 = Math.min(ns, rz + ns) - 1;
        rz++;
        final int tz_1 = Math.max(-1, tz - 1);
        final int tz_d_1 = Math.min(ns, tz + ns) - 1;
        tz--;
        final int d = rz_d_1 - rz_1;
        for (int r = iy, j = 0; r < iyh; r++, j++) {
            final int base = s * nrByNc + r * nc;
            final int hd = h[j] * d;
            for (int c = ix, i = 0; c < ixw; c++, i++) {
                final double sumXy = buffer[base + c];
                compute(rx1[i], ry1[j], rz_1, rxw1[i], ryh1[j], rz_d_1, width[i], h[j], d, rs, rss, rsum);
                compute(tx1[i], ty1[j], tz_1, txw1[i], tyh1[j], tz_d_1, width[i], h[j], d, ts, tss, tsum);
                // Compute the correlation
                // (sumXy - sumX*sumY/n) / sqrt( (sumXx - sumX^2 / n) * (sumYy - sumY^2 / n) )
                final int n = width[i] * hd;
                final double numerator = sumXy - (rsum[X] * tsum[Y] / n);
                final double denominator1 = rsum[XX] - (rsum[X] * rsum[X] / n);
                final double denominator2 = tsum[YY] - (tsum[Y] * tsum[Y] / n);
                double corr;
                if (denominator1 == 0 || denominator2 == 0) {
                    // If there is data and all the variances are the same then correlation is perfect
                    if (rsum[XX] == tsum[YY] && rsum[XX] == sumXy && rsum[XX] > 0) {
                        corr = 1;
                    } else {
                        corr = 0;
                    }
                } else {
                    // Leave as raw for debugging, i.e. do not clip to range [-1:1]
                    corr = numerator / Math.sqrt(denominator1 * denominator2);
                }
                buffer[base + c] = corr;
                if (n < minimumN) {
                    continue;
                }
                // Check normalisation with some margin for error
                if (corr > 1.0001) {
                    // It is likely to occur at the bounds.
                    continue;
                }
                // Skip undefined values (e.g. 0/0 when the overlap is empty)
                if (Double.isNaN(corr)) {
                    continue;
                }
                if (maxj < 0 || corr > max) {
                    // First valid candidate or a new maximum.
                    // Accepting the first candidate unconditionally avoids using an
                    // invalid index (-1) when no correlation is positive.
                    max = corr;
                    maxj = base + c;
                } else if (corr == max) {
                    // Tie: prefer the position with the smaller shift from the centre
                    final int[] xyz1 = correlation.getXyz(maxj);
                    final int[] xyz2 = correlation.getXyz(base + c);
                    int d1 = 0;
                    int d2 = 0;
                    for (int k = 0; k < 3; k++) {
                        d1 += MathUtils.pow2(xyz1[k] - centre[k]);
                        d2 += MathUtils.pow2(xyz2[k] - centre[k]);
                    }
                    if (d2 < d1) {
                        max = corr;
                        maxj = base + c;
                    }
                }
            }
        }
    }

    // The maximum correlation with normalisation. If no position satisfied the overlap and
    // normalisation checks then fall back to the unnormalised peak (which lies inside the
    // normalised crop region, so buffer[maxi] holds its normalised value).
    if (maxj >= 0) {
        maxi = maxj;
    }
    final int[] xyz = correlation.getXyz(maxi);
    // Report the shift required to move from the centre of the target image to the reference
    final double[] result = new double[] { nc_2 - xyz[0], nr_2 - xyz[1], ns_2 - xyz[2], buffer[maxi] };
    if (refinements > 0) {
        // Create a cubic spline using a small region of pixels around the maximum
        if (calc == null) {
            calc = new CubicSplineCalculator();
        }
        // Avoid out-of-bounds errors. Only use the range that was normalised
        final int x = MathUtils.clip(ix, ixw - 4, xyz[0] - 1);
        final int y = MathUtils.clip(iy, iyh - 4, xyz[1] - 1);
        final int z = MathUtils.clip(iz, izd - 4, xyz[2] - 1);
        final DoubleImage3D crop = correlation.crop(x, y, z, 4, 4, 4, region);
        region = crop.getData();
        final CustomTricubicFunction f = CustomTricubicFunctionUtils.create(calc.compute(region));
        // Find the maximum starting at the current origin
        final int ox = xyz[0] - x;
        final int oy = xyz[1] - y;
        final int oz = xyz[2] - z;
        // Scale to the cubic spline dimensions of 0-1 (4 samples span 3 intervals)
        final double[] origin = new double[] { ox / 3.0, oy / 3.0, oz / 3.0 };
        if (searchMode == SearchMode.BINARY) {
            // Simple condensing search.
            // Can this use the current origin as a start point?
            // Currently we evaluate 8-cube vertices. A better search
            // would evaluate 27 points around the optimum, pick the best then condense
            // the range.
            final double[] optimum = f.search(true, refinements, relativeThreshold, -1);
            final double value = optimum[3];
            if (value > result[3]) {
                result[3] = value;
                // Convert the maximum back with scaling
                for (int i = 0; i < 3; i++) {
                    result[i] -= (optimum[i] - origin[i]) * 3.0;
                }
                return result;
            }
        } else {
            // Gradient search
            try {
                final SplineFunction sf = new SplineFunction(f, origin);
                final BoundedNonLinearConjugateGradientOptimizer optimiser =
                    new BoundedNonLinearConjugateGradientOptimizer(
                        BoundedNonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
                        // set the number of refinements
                        new SimpleValueChecker(relativeThreshold, -1, refinements));
                final PointValuePair opt = optimiser.optimize(
                    maxEvaluations,
                    bounds,
                    GoalType.MINIMIZE,
                    new InitialGuess(origin),
                    // Scale the error for the position check
                    new PositionChecker(-1, error / 3.0),
                    new ObjectiveFunction(sf::value),
                    new ObjectiveFunctionGradient(point -> {
                        // This must be new each time
                        final double[] partialDerivative1 = new double[3];
                        sf.value(point, partialDerivative1);
                        return partialDerivative1;
                    }));
                // Check it is higher. Invert since we did a minimisation.
                final double value = -opt.getValue();
                if (value > result[3]) {
                    result[3] = value;
                    // Convert the maximum back with scaling
                    final double[] optimum = opt.getPointRef();
                    for (int i = 0; i < 3; i++) {
                        result[i] -= (optimum[i] - origin[i]) * 3.0;
                    }
                    return result;
                }
            } catch (final Exception ignored) {
                // Best-effort refinement: if the optimiser fails for any reason the
                // pixel-level result computed above is returned unchanged.
            }
        }
    }
    return result;
}
Also used : CubicSplineCalculator(uk.ac.sussex.gdsc.smlm.function.cspline.CubicSplineCalculator) Arrays(java.util.Arrays) BoundedNonLinearConjugateGradientOptimizer(uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) CubicSplinePosition(uk.ac.sussex.gdsc.core.math.interpolation.CubicSplinePosition) ImageProcessor(ij.process.ImageProcessor) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) PointValuePair(org.apache.commons.math3.optim.PointValuePair) Logger(java.util.logging.Logger) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) CustomTricubicFunctionUtils(uk.ac.sussex.gdsc.core.math.interpolation.CustomTricubicFunctionUtils) CustomTricubicFunction(uk.ac.sussex.gdsc.core.math.interpolation.CustomTricubicFunction) ImageWindow(uk.ac.sussex.gdsc.core.utils.ImageWindow) PositionChecker(uk.ac.sussex.gdsc.smlm.math3.optim.PositionChecker) ImageStack(ij.ImageStack) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) InitialGuess(org.apache.commons.math3.optim.InitialGuess) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) MaxEval(org.apache.commons.math3.optim.MaxEval) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) DoubleEquality(uk.ac.sussex.gdsc.core.utils.DoubleEquality) CubicSplineCalculator(uk.ac.sussex.gdsc.smlm.function.cspline.CubicSplineCalculator) InitialGuess(org.apache.commons.math3.optim.InitialGuess) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) PointValuePair(org.apache.commons.math3.optim.PointValuePair) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) 
PositionChecker(uk.ac.sussex.gdsc.smlm.math3.optim.PositionChecker) BoundedNonLinearConjugateGradientOptimizer(uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) CustomTricubicFunction(uk.ac.sussex.gdsc.core.math.interpolation.CustomTricubicFunction)

Aggregations

OptimizationData (org.apache.commons.math3.optim.OptimizationData)4 PositionChecker (org.apache.commons.math3.optim.PositionChecker)3 InitialGuess (org.apache.commons.math3.optim.InitialGuess)2 MaxEval (org.apache.commons.math3.optim.MaxEval)2 PointValuePair (org.apache.commons.math3.optim.PointValuePair)2 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)2 SimpleValueChecker (org.apache.commons.math3.optim.SimpleValueChecker)2 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)2 ObjectiveFunctionGradient (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient)2 PositionChecker (uk.ac.sussex.gdsc.smlm.math3.optim.PositionChecker)2 BoundedNonLinearConjugateGradientOptimizer (uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer)2 ImageStack (ij.ImageStack)1 ImageProcessor (ij.process.ImageProcessor)1 Arrays (java.util.Arrays)1 Logger (java.util.logging.Logger)1 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)1 MathUnsupportedOperationException (org.apache.commons.math3.exception.MathUnsupportedOperationException)1 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)1 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)1 BaseOptimizer (org.apache.commons.math3.optim.BaseOptimizer)1