Search in sources :

Example 1 with PhaseCorrelationPeak

use of mpicbg.imglib.algorithm.fft.PhaseCorrelationPeak in project TrakEM2 by trakem2.

In the class StitchingTEM, method correlate.

/**
 * Estimates the position of the {@code moving} Patch relative to the {@code base} Patch
 * by comparing a stripe of each image along their shared edge.
 * <p>
 * Phase-correlation is tried first, on stripes that start at {@code percent_overlap} of the
 * image and grow in steps of 10% up to 100% until the cross-correlation coefficient of the
 * best peak reaches {@code min_R}. If no attempt succeeds, a plain cross-correlation is run
 * once at one third of {@code scale} on stripes of twice {@code percent_overlap} (capped at
 * 100%), and accepted when its coefficient exceeds {@code min_R / 2}.
 *
 * @param base The reference Patch; its bottom (TOP_BOTTOM) or right (LEFT_RIGHT) stripe is used.
 * @param moving The Patch to position; its top (TOP_BOTTOM) or left (LEFT_RIGHT) stripe is used.
 * @param percent_overlap The minimum chunk of adjacent images to compare with, as a fraction of the image size; will automatically and gradually increase to 100% if no good matches are found.
 * @param scale For optimizing the speed of phase- and cross-correlation: the stripes are created at this scale and the measured shift is divided by it.
 * @param direction Either TOP_BOTTOM or LEFT_RIGHT.
 * @param default_dx The X displacement to return when no acceptable match is found.
 * @param default_dy The Y displacement to return when no acceptable match is found.
 * @param min_R The minimum coefficient for a phase-correlation result to be accepted; half of it is required for the cross-correlation fallback.
 * @return a double[4] array containing:<ul>
 * <li>x2: relative X position of the second Patch</li>
 * <li>y2: relative Y position of the second Patch</li>
 * <li>flag: ERROR or SUCCESS</li>
 * <li>R: cross-correlation coefficient (0 when the flag is ERROR)</li>
 * </ul>
 */
public static double[] correlate(final Patch base, final Patch moving, final float percent_overlap, final double scale, final int direction, final double default_dx, final double default_dy, final double min_R) {
    ImageProcessor ip1, ip2;
    final Rectangle b1 = base.getBoundingBox(null);
    final Rectangle b2 = moving.getBoundingBox(null);
    final int w1 = b1.width, h1 = b1.height, w2 = b2.width, h2 = b2.height;
    Roi roi1 = null, roi2 = null;
    float overlap = percent_overlap;
    double dx = default_dx, dy = default_dy;
    // Tolerance for deciding that the overlap has reached 100% despite float rounding
    final float full_overlap_tolerance = 0.001f;
    boolean exhausted;
    // Iterate until the coefficient reaches min_R, or there's no more image overlap to feed
    do {
        // create rois for the stripes
        switch(direction) {
            case TOP_BOTTOM:
                // bottom
                roi1 = new Roi(0, h1 - (int) (h1 * overlap), w1, (int) (h1 * overlap));
                // top
                roi2 = new Roi(0, 0, w2, (int) (h2 * overlap));
                break;
            case LEFT_RIGHT:
                // right
                roi1 = new Roi(w1 - (int) (w1 * overlap), 0, (int) (w1 * overlap), h1);
                // left
                roi2 = new Roi(0, 0, (int) (w2 * overlap), h2);
                break;
        }
        // will apply the transform if necessary
        ip1 = makeStripe(base, roi1, scale);
        ip2 = makeStripe(moving, roi2, scale);
        // gaussian blur them (sigma 1.0) before phase-correlation
        ip1.setPixels(ImageFilter.computeGaussianFastMirror(new FloatArray2D((float[]) ip1.getPixels(), ip1.getWidth(), ip1.getHeight()), 1.0).data);
        ip2.setPixels(ImageFilter.computeGaussianFastMirror(new FloatArray2D((float[]) ip2.getPixels(), ip2.getWidth(), ip2.getHeight()), 1.0).data);
        final ImagePlus imp1 = new ImagePlus("", ip1);
        final ImagePlus imp2 = new ImagePlus("", ip2);
        final PhaseCorrelationCalculator t = new PhaseCorrelationCalculator(imp1, imp2);
        final PhaseCorrelationPeak peak = t.getPeak();
        final double resultR = peak.getCrossCorrelationPeak();
        final int[] peakPosition = peak.getPosition();
        if (resultR >= min_R) {
            // success: convert the shift back to unscaled coordinates and offset it
            // by the origin of the base stripe along the stitching axis
            final int success = SUCCESS;
            switch(direction) {
                case TOP_BOTTOM:
                    dx = peakPosition[0] / scale;
                    dy = roi1.getBounds().y + peakPosition[1] / scale;
                    break;
                case LEFT_RIGHT:
                    dx = roi1.getBounds().x + peakPosition[0] / scale;
                    dy = peakPosition[1] / scale;
                    break;
            }
            return new double[] { dx, dy, success, resultR };
        }
        // Stop once the attempt that just failed already covered the whole image;
        // otherwise grow the stripes by 10%, never beyond 100% so the rois stay
        // within the image bounds.
        exhausted = overlap >= 1.0f - full_overlap_tolerance;
        overlap = Math.min(1.0f, overlap + 0.10f);
    } while (!exhausted);
    // Phase-correlation failed, fall back to cross-correlation with a safe overlap
    overlap = percent_overlap * 2;
    if (overlap > 1.0f) {
        overlap = 1.0f;
    }
    switch(direction) {
        case TOP_BOTTOM:
            // bottom
            roi1 = new Roi(0, h1 - (int) (h1 * overlap), w1, (int) (h1 * overlap));
            // top
            roi2 = new Roi(0, 0, w2, (int) (h2 * overlap));
            break;
        case LEFT_RIGHT:
            // right
            roi1 = new Roi(w1 - (int) (w1 * overlap), 0, (int) (w1 * overlap), h1);
            // left
            roi2 = new Roi(0, 0, (int) (w2 * overlap), h2);
            break;
    }
    // use one third of the size used for phase-correlation though! Otherwise, it may take FOREVER
    final double scale_cc = scale / 3.0f;
    ip1 = makeStripe(base, roi1, scale_cc);
    ip2 = makeStripe(moving, roi2, scale_cc);
    // gaussian blur them before cross-correlation
    ip1.setPixels(ImageFilter.computeGaussianFastMirror(new FloatArray2D((float[]) ip1.getPixels(), ip1.getWidth(), ip1.getHeight()), 1f).data);
    ip2.setPixels(ImageFilter.computeGaussianFastMirror(new FloatArray2D((float[]) ip2.getPixels(), ip2.getWidth(), ip2.getHeight()), 1f).data);
    final CrossCorrelation2D cc = new CrossCorrelation2D(ip1, ip2, false);
    double[] cc_result = null;
    // NOTE(review): the first two arguments are swapped between directions; presumably
    // they bound the search along each axis -- confirm against CrossCorrelation2D.
    switch(direction) {
        case TOP_BOTTOM:
            cc_result = cc.computeCrossCorrelationMT(0.9, 0.3, false);
            break;
        case LEFT_RIGHT:
            cc_result = cc.computeCrossCorrelationMT(0.3, 0.9, false);
            break;
    }
    // cc_result stays null when direction is neither TOP_BOTTOM nor LEFT_RIGHT:
    // report that as a failure instead of throwing a NullPointerException.
    if (cc_result != null && cc_result[2] > min_R / 2) {
        // accepting if R is above half the R accepted for Phase Correlation
        final int success = SUCCESS;
        switch(direction) {
            case TOP_BOTTOM:
                dx = cc_result[0] / scale_cc;
                dy = roi1.getBounds().y + cc_result[1] / scale_cc;
                break;
            case LEFT_RIGHT:
                dx = roi1.getBounds().x + cc_result[0] / scale_cc;
                dy = cc_result[1] / scale_cc;
                break;
        }
        return new double[] { dx, dy, success, cc_result[2] };
    }
    // No acceptable match. The shift is deliberately not bounds-checked against
    // default_dx/default_dy: such checks don't work if the defaults are zero,
    // and may actually be harmful in any case.
    return new double[] { default_dx, default_dy, ERROR, 0 };
}
Also used : FloatArray2D(mpi.fruitfly.math.datastructures.FloatArray2D) Rectangle(java.awt.Rectangle) Point(mpicbg.models.Point) Roi(ij.gui.Roi) ImagePlus(ij.ImagePlus) Point(mpicbg.models.Point) ImageProcessor(ij.process.ImageProcessor) PhaseCorrelationPeak(mpicbg.imglib.algorithm.fft.PhaseCorrelationPeak) CrossCorrelation2D(mpi.fruitfly.registration.CrossCorrelation2D)

Aggregations

ImagePlus (ij.ImagePlus)1 Roi (ij.gui.Roi)1 ImageProcessor (ij.process.ImageProcessor)1 Rectangle (java.awt.Rectangle)1 FloatArray2D (mpi.fruitfly.math.datastructures.FloatArray2D)1 CrossCorrelation2D (mpi.fruitfly.registration.CrossCorrelation2D)1 PhaseCorrelationPeak (mpicbg.imglib.algorithm.fft.PhaseCorrelationPeak)1 Point (mpicbg.models.Point)1