Use of mpicbg.models.AbstractModel in project TrakEM2 by trakem2.
Class ElasticLayerAlignment, method exec.
/**
 * Elastically aligns the layers in {@code layerRange}.
 *
 * The method creates one tile and one spring mesh per layer, collects layer
 * pairs (via {@code preAlignStack} unless {@code param.isAligned}), runs
 * block matching for every pair that is not fixed on both sides, pre-aligns
 * the tiles, optimizes the spring meshes, and finally applies the resulting
 * transforms to the patches and vector data of each layer and updates the
 * patch mipmaps.
 *
 * @param param alignment parameters (model indices, layer scale, spring mesh
 *        and optimizer settings, {@code useTps}, ...)
 * @param project project whose loader is asked to release memory and which
 *        is handed to {@code preAlignStack}
 * @param layerRange layers to align; if empty, the method logs and returns
 * @param fixedLayers layers whose tiles are fixed in the initial tile
 *        configuration; pairs with both layers fixed are not block matched,
 *        and fixed layers are skipped when updating mipmaps
 * @param emptyLayers layers that are skipped when updating mipmaps
 * @param box bounding box; its size times {@code param.layerScale} is the
 *        mesh size and its origin is added back after optimization
 * @param propagateTransformBefore if true, the first layer's transform is
 *        also applied to all layers of the layer set preceding the range
 * @param propagateTransformAfter if true, the last layer's transform (TPS if
 *        {@code param.useTps}, otherwise MLS) is also applied to all layers
 *        of the layer set following the range
 * @param filter patch filter passed to block matching, to
 *        {@code applyTransformToLayer} and to the mipmap update
 * @throws Exception e.g. when a block matching task fails
 */
@SuppressWarnings("deprecation")
public final void exec(final Param param, final Project project, final List<Layer> layerRange, final Set<Layer> fixedLayers, final Set<Layer> emptyLayers, final Rectangle box, final boolean propagateTransformBefore, final boolean propagateTransformAfter, final Filter<Patch> filter) throws Exception {
    /* layerRange.get(0) is dereferenced further down; bail out before doing any work */
    if (layerRange.isEmpty()) {
        Utils.log("No layers to align.");
        return;
    }
    final ExecutorService service = ExecutorProvider.getExecutorService(1.0f);
    /* create tiles and models for all layers */
    final ArrayList<Tile<?>> tiles = new ArrayList<Tile<?>>();
    for (int i = 0; i < layerRange.size(); ++i) {
        switch(param.desiredModelIndex) {
            case 0:
                tiles.add(new Tile<TranslationModel2D>(new TranslationModel2D()));
                break;
            case 1:
                tiles.add(new Tile<RigidModel2D>(new RigidModel2D()));
                break;
            case 2:
                tiles.add(new Tile<SimilarityModel2D>(new SimilarityModel2D()));
                break;
            case 3:
                tiles.add(new Tile<AffineModel2D>(new AffineModel2D()));
                break;
            case 4:
                tiles.add(new Tile<HomographyModel2D>(new HomographyModel2D()));
                break;
            default:
                /* unknown model index: nothing is aligned */
                return;
        }
    }
    /* collect all pairs of slices for which a model could be found */
    final ArrayList<Triple<Integer, Integer, AbstractModel<?>>> pairs = new ArrayList<Triple<Integer, Integer, AbstractModel<?>>>();
    if (!param.isAligned) {
        preAlignStack(param, project, layerRange, box, filter, pairs);
    } else {
        /* already aligned: pair every layer with its next maxNumNeighbors layers */
        for (int i = 0; i < layerRange.size(); ++i) {
            final int range = Math.min(layerRange.size(), i + param.maxNumNeighbors + 1);
            for (int j = i + 1; j < range; ++j) {
                pairs.add(new Triple<Integer, Integer, AbstractModel<?>>(i, j, new TranslationModel2D()));
            }
        }
    }
    /* Elastic alignment */
    /* Initialization */
    final TileConfiguration initMeshes = new TileConfiguration();
    final int meshWidth = (int) Math.ceil(box.width * param.layerScale);
    final int meshHeight = (int) Math.ceil(box.height * param.layerScale);
    final ArrayList<SpringMesh> meshes = new ArrayList<SpringMesh>(layerRange.size());
    for (int i = 0; i < layerRange.size(); ++i) {
        meshes.add(new SpringMesh(param.resolutionSpringMesh, meshWidth, meshHeight, param.stiffnessSpringMesh, param.maxStretchSpringMesh * param.layerScale, param.dampSpringMesh));
    }
    final int blockRadius = Math.max(16, mpicbg.util.Util.roundPos(param.layerScale * param.blockRadius));
    Utils.log("effective block radius = " + blockRadius);
    /* submit one block matching task per pair that has at least one non-fixed layer */
    final ArrayList<Future<BlockMatchPairCallable.BlockMatchResults>> futures = new ArrayList<Future<BlockMatchPairCallable.BlockMatchResults>>(pairs.size());
    for (final Triple<Integer, Integer, AbstractModel<?>> pair : pairs) {
        /* free memory */
        project.getLoader().releaseAll();
        final SpringMesh m1 = meshes.get(pair.a);
        final SpringMesh m2 = meshes.get(pair.b);
        final ArrayList<Vertex> v1 = m1.getVertices();
        final ArrayList<Vertex> v2 = m2.getVertices();
        final Layer layer1 = layerRange.get(pair.a);
        final Layer layer2 = layerRange.get(pair.b);
        final boolean layer1Fixed = fixedLayers.contains(layer1);
        final boolean layer2Fixed = fixedLayers.contains(layer2);
        if (!(layer1Fixed && layer2Fixed)) {
            final BlockMatchPairCallable bmpc = new BlockMatchPairCallable(pair, layerRange, layer1Fixed, layer2Fixed, filter, param, v1, v2, box);
            futures.add(service.submit(bmpc));
        }
    }
    /* turn the block matching results into springs and tile connections */
    for (final Future<BlockMatchPairCallable.BlockMatchResults> future : futures) {
        final BlockMatchPairCallable.BlockMatchResults results = future.get();
        final Collection<PointMatch> pm12 = results.pm12, pm21 = results.pm21;
        final Triple<Integer, Integer, AbstractModel<?>> pair = results.pair;
        final Tile<?> t1 = tiles.get(pair.a);
        final Tile<?> t2 = tiles.get(pair.b);
        final SpringMesh m1 = meshes.get(pair.a);
        final SpringMesh m2 = meshes.get(pair.b);
        /* springs get weaker with increasing distance between the two layers */
        final double springConstant = 1.0 / (pair.b - pair.a);
        final boolean layer1Fixed = results.layer1Fixed;
        final boolean layer2Fixed = results.layer2Fixed;
        if (layer1Fixed) {
            initMeshes.fixTile(t1);
        } else {
            if (param.useLocalSmoothnessFilter) {
                Utils.log(pair.a + " > " + pair.b + ": " + pm12.size() + " candidates passed local smoothness filter.");
            } else {
                Utils.log(pair.a + " > " + pair.b + ": found " + pm12.size() + " correspondences.");
            }
            for (final PointMatch pm : pm12) {
                final Vertex p1 = (Vertex) pm.getP1();
                final Vertex p2 = new Vertex(pm.getP2());
                p1.addSpring(p2, new Spring(0, springConstant));
                m2.addPassiveVertex(p2);
            }
            /*
             * adding Tiles to the initialing TileConfiguration, adding a Tile
             * multiple times does not harm because the TileConfiguration is
             * backed by a Set.
             */
            if (pm12.size() > pair.c.getMinNumMatches()) {
                initMeshes.addTile(t1);
                initMeshes.addTile(t2);
                t1.connect(t2, pm12);
            }
        }
        if (layer2Fixed) {
            initMeshes.fixTile(t2);
        } else {
            if (param.useLocalSmoothnessFilter) {
                Utils.log(pair.a + " < " + pair.b + ": " + pm21.size() + " candidates passed local smoothness filter.");
            } else {
                Utils.log(pair.a + " < " + pair.b + ": found " + pm21.size() + " correspondences.");
            }
            for (final PointMatch pm : pm21) {
                final Vertex p1 = (Vertex) pm.getP1();
                final Vertex p2 = new Vertex(pm.getP2());
                p1.addSpring(p2, new Spring(0, springConstant));
                m1.addPassiveVertex(p2);
            }
            /*
             * adding Tiles to the initialing TileConfiguration, adding a Tile
             * multiple times does not harm because the TileConfiguration is
             * backed by a Set.
             */
            if (pm21.size() > pair.c.getMinNumMatches()) {
                initMeshes.addTile(t1);
                initMeshes.addTile(t2);
                t2.connect(t1, pm21);
            }
        }
        Utils.log(pair.a + " <> " + pair.b + " spring constant = " + springConstant);
    }
    /* pre-align by optimizing a piecewise linear model */
    initMeshes.optimize(param.maxEpsilon * param.layerScale, param.maxIterationsSpringMesh, param.maxPlateauwidthSpringMesh);
    for (int i = 0; i < layerRange.size(); ++i) {
        meshes.get(i).init(tiles.get(i).getModel());
    }
    /* optimize the meshes */
    try {
        final long t0 = System.currentTimeMillis();
        Utils.log("Optimizing spring meshes...");
        if (param.useLegacyOptimizer) {
            Utils.log(" ...using legacy optimizer...");
            SpringMesh.optimizeMeshes2(meshes, param.maxEpsilon * param.layerScale, param.maxIterationsSpringMesh, param.maxPlateauwidthSpringMesh, param.visualize);
        } else {
            SpringMesh.optimizeMeshes(meshes, param.maxEpsilon * param.layerScale, param.maxIterationsSpringMesh, param.maxPlateauwidthSpringMesh, param.visualize);
        }
        Utils.log("Done optimizing spring meshes. Took " + (System.currentTimeMillis() - t0) + " ms");
    } catch (final NotEnoughDataPointsException e) {
        Utils.log("There were not enough data points to get the spring mesh optimizing.");
        e.printStackTrace();
        return;
    }
    /* translate relative to bounding box: undo the layer scale and add the box origin */
    for (final SpringMesh mesh : meshes) {
        for (final PointMatch pm : mesh.getVA().keySet()) {
            final Point p1 = pm.getP1();
            final Point p2 = pm.getP2();
            final double[] l = p1.getL();
            final double[] w = p2.getW();
            l[0] = l[0] / param.layerScale + box.x;
            l[1] = l[1] / param.layerScale + box.y;
            w[0] = w[0] / param.layerScale + box.x;
            w[1] = w[1] / param.layerScale + box.y;
        }
    }
    /* free memory */
    project.getLoader().releaseAll();
    final Layer first = layerRange.get(0);
    final LayerSet ls = first.getParent();
    final List<Layer> layers = ls.getLayers();
    final Area infArea = AreaUtils.infiniteArea();
    /* gather all vector data of the layer set so it can be transformed along with the patches */
    final List<VectorData> vectorData = new ArrayList<VectorData>();
    for (final Layer layer : ls.getLayers()) {
        vectorData.addAll(Utils.castCollection(layer.getDisplayables(VectorData.class, false, true), VectorData.class, true));
    }
    vectorData.addAll(Utils.castCollection(ls.getZDisplayables(VectorData.class, true), VectorData.class, true));
    /* transfer layer transform into patch transforms and append to patches */
    if (propagateTransformBefore || propagateTransformAfter) {
        if (propagateTransformBefore) {
            /* NOTE(review): always TPS here, whereas the "after" branch honours param.useTps - confirm intended */
            final ThinPlateSplineTransform tps = makeTPS(meshes.get(0).getVA().keySet());
            final int firstLayerIndex = first.getParent().getLayerIndex(first.getId());
            for (int i = 0; i < firstLayerIndex; ++i) {
                applyTransformToLayer(layers.get(i), tps, filter);
                for (final VectorData vd : vectorData) {
                    vd.apply(layers.get(i), infArea, tps);
                }
            }
        }
        if (propagateTransformAfter) {
            final Layer last = layerRange.get(layerRange.size() - 1);
            final CoordinateTransform ct;
            if (param.useTps) {
                ct = makeTPS(meshes.get(meshes.size() - 1).getVA().keySet());
            } else {
                final MovingLeastSquaresTransform2 mls = new MovingLeastSquaresTransform2();
                mls.setMatches(meshes.get(meshes.size() - 1).getVA().keySet());
                ct = mls;
            }
            final int lastLayerIndex = last.getParent().getLayerIndex(last.getId());
            for (int i = lastLayerIndex + 1; i < layers.size(); ++i) {
                applyTransformToLayer(layers.get(i), ct, filter);
                for (final VectorData vd : vectorData) {
                    vd.apply(layers.get(i), infArea, ct);
                }
            }
        }
    }
    /* start the progress bar once; resetting it inside the loop made it jump back on every layer */
    IJ.showProgress(0, layerRange.size());
    for (int l = 0; l < layerRange.size(); ++l) {
        IJ.showStatus("Applying transformation to patches ...");
        final Layer layer = layerRange.get(l);
        final ThinPlateSplineTransform tps = makeTPS(meshes.get(l).getVA().keySet());
        applyTransformToLayer(layer, tps, filter);
        for (final VectorData vd : vectorData) {
            vd.apply(layer, infArea, tps);
        }
        /*
         * NOTE(review): Thread.interrupted() clears the interrupt flag and the
         * loop keeps going; the message suggests an abort may have been
         * intended - confirm before changing.
         */
        if (Thread.interrupted()) {
            Utils.log("Interrupted during applying transformations to patches. No all patches have been updated. Re-generate mipmaps manually.");
        }
        IJ.showProgress(l + 1, layerRange.size());
    }
    /* update patch mipmaps for every layer that may have been transformed */
    final int firstLayerIndex;
    final int lastLayerIndex;
    if (propagateTransformBefore) {
        firstLayerIndex = 0;
    } else {
        firstLayerIndex = first.getParent().getLayerIndex(first.getId());
    }
    if (propagateTransformAfter) {
        lastLayerIndex = layers.size() - 1;
    } else {
        final Layer last = layerRange.get(layerRange.size() - 1);
        lastLayerIndex = last.getParent().getLayerIndex(last.getId());
    }
    for (int i = firstLayerIndex; i <= lastLayerIndex; ++i) {
        final Layer layer = layers.get(i);
        if (!(emptyLayers.contains(layer) || fixedLayers.contains(layer))) {
            for (final Patch patch : AlignmentUtils.filterPatches(layer, filter)) {
                patch.updateMipMaps();
            }
        }
    }
    Utils.log("Done.");
}
Use of mpicbg.models.AbstractModel in project TrakEM2 by trakem2.
Class ElasticLayerAlignment, method preAlignStack.
/**
 * Finds, for each layer in {@code layerRange}, the following layers (up to
 * {@code param.maxNumNeighbors}) for which a transformation model can be
 * estimated from feature correspondences, and appends one
 * {@code Triple(indexA, indexB, model)} per successful pair to {@code pairs}.
 *
 * Features are extracted once for all layers. Candidate matches are read
 * from the cache unless {@code param.ppm.clearCache} is set; otherwise they
 * are computed, scaled by {@code param.layerScale / scale} and stored.
 * Candidates are filtered with RANSAC using the model selected by
 * {@code param.expectedModelIndex}. Matching runs in batches of at most
 * {@code param.maxNumThreads} threads. After more than
 * {@code param.maxNumFailures} consecutive failed pairs the neighbour search
 * for the current layer is stopped.
 *
 * @param param alignment parameters
 * @param project project used for (de)serializing features and matches
 * @param layerRange layers to be matched
 * @param box bounding box used to derive the feature extraction scale
 * @param filter patch filter passed to feature extraction
 * @param pairs output list receiving the successfully matched pairs
 */
private void preAlignStack(final Param param, final Project project, final List<Layer> layerRange, final Rectangle box, final Filter<Patch> filter, final ArrayList<Triple<Integer, Integer, AbstractModel<?>>> pairs) {
    /* never upscale; shrink so that the box fits into the maximum octave size */
    final double scale = Math.min(1.0, Math.min((double) param.ppm.sift.maxOctaveSize / (double) box.width, (double) param.ppm.sift.maxOctaveSize / (double) box.height));
    /* extract and save features, overwrite cached files if requested */
    try {
        AlignmentUtils.extractAndSaveLayerFeatures(layerRange, box, scale, filter, param.ppm.sift, param.ppm.clearCache, param.ppm.maxNumThreadsSift);
    } catch (final Exception e) {
        /* do not fail silently: report why no pairs will be produced */
        Utils.log("Feature extraction failed: " + e);
        e.printStackTrace();
        if (e instanceof InterruptedException) {
            /* keep the interrupt visible to the caller */
            Thread.currentThread().interrupt();
        }
        return;
    }
    /* match and filter feature correspondences */
    int numFailures = 0;
    final double pointMatchScale = param.layerScale / scale;
    for (int i = 0; i < layerRange.size(); ++i) {
        final ArrayList<Thread> threads = new ArrayList<Thread>(param.maxNumThreads);
        final int sliceA = i;
        final Layer layerA = layerRange.get(i);
        final int range = Math.min(layerRange.size(), i + param.maxNumNeighbors + 1);
        final String layerNameA = layerName(layerA);
        /*
         * The label must sit on the loop itself: a label on the loop body
         * would make "break J" merely end the current iteration instead of
         * stopping the neighbour search.
         */
        J: for (int j = i + 1; j < range; ) {
            final int numThreads = Math.min(param.maxNumThreads, range - j);
            /* one result slot per thread of this batch; null means "no model found" */
            final ArrayList<Triple<Integer, Integer, AbstractModel<?>>> models = new ArrayList<Triple<Integer, Integer, AbstractModel<?>>>(numThreads);
            for (int k = 0; k < numThreads; ++k) {
                models.add(null);
            }
            for (int t = 0; t < numThreads && j < range; ++t, ++j) {
                final int ti = t;
                final int sliceB = j;
                final Layer layerB = layerRange.get(j);
                final String layerNameB = layerName(layerB);
                final Thread thread = new Thread() {
                    @Override
                    public void run() {
                        IJ.showProgress(sliceA, layerRange.size() - 1);
                        Utils.log("matching " + layerNameB + " -> " + layerNameA + "...");
                        ArrayList<PointMatch> candidates = null;
                        if (!param.ppm.clearCache) {
                            candidates = mpicbg.trakem2.align.Util.deserializePointMatches(project, param.ppm, "layer", layerB.getId(), layerA.getId());
                        }
                        if (null == candidates) {
                            final ArrayList<Feature> fs1 = mpicbg.trakem2.align.Util.deserializeFeatures(project, param.ppm.sift, "layer", layerA.getId());
                            final ArrayList<Feature> fs2 = mpicbg.trakem2.align.Util.deserializeFeatures(project, param.ppm.sift, "layer", layerB.getId());
                            candidates = new ArrayList<PointMatch>(FloatArray2DSIFT.createMatches(fs2, fs1, param.ppm.rod));
                            /* scale the candidates */
                            for (final PointMatch pm : candidates) {
                                final Point p1 = pm.getP1();
                                final Point p2 = pm.getP2();
                                final double[] l1 = p1.getL();
                                final double[] w1 = p1.getW();
                                final double[] l2 = p2.getL();
                                final double[] w2 = p2.getW();
                                l1[0] *= pointMatchScale;
                                l1[1] *= pointMatchScale;
                                w1[0] *= pointMatchScale;
                                w1[1] *= pointMatchScale;
                                l2[0] *= pointMatchScale;
                                l2[1] *= pointMatchScale;
                                w2[0] *= pointMatchScale;
                                w2[1] *= pointMatchScale;
                            }
                            if (!mpicbg.trakem2.align.Util.serializePointMatches(project, param.ppm, "layer", layerB.getId(), layerA.getId(), candidates)) {
                                Utils.log("Could not store point match candidates for layers " + layerNameB + " and " + layerNameA + ".");
                            }
                        }
                        AbstractModel<?> model;
                        switch(param.expectedModelIndex) {
                            case 0:
                                model = new TranslationModel2D();
                                break;
                            case 1:
                                model = new RigidModel2D();
                                break;
                            case 2:
                                model = new SimilarityModel2D();
                                break;
                            case 3:
                                model = new AffineModel2D();
                                break;
                            case 4:
                                model = new HomographyModel2D();
                                break;
                            default:
                                /* unknown model index: the slot stays null and counts as a failure */
                                return;
                        }
                        final ArrayList<PointMatch> inliers = new ArrayList<PointMatch>();
                        boolean modelFound;
                        boolean again = false;
                        try {
                            do {
                                again = false;
                                modelFound = model.filterRansac(candidates, inliers, 1000, param.maxEpsilon * param.layerScale, param.minInlierRatio, param.minNumInliers, 3);
                                if (modelFound && param.rejectIdentity) {
                                    final ArrayList<Point> points = new ArrayList<Point>();
                                    PointMatch.sourcePoints(inliers, points);
                                    if (Transforms.isIdentity(model, points, param.identityTolerance * param.layerScale)) {
                                        /* drop the identity inliers and retry with the remaining candidates */
                                        IJ.log("Identity transform for " + inliers.size() + " matches rejected.");
                                        candidates.removeAll(inliers);
                                        inliers.clear();
                                        again = true;
                                    }
                                }
                            } while (again);
                        } catch (final NotEnoughDataPointsException e) {
                            modelFound = false;
                        }
                        if (modelFound) {
                            Utils.log(layerNameB + " -> " + layerNameA + ": " + inliers.size() + " corresponding features with an average displacement of " + (PointMatch.meanDistance(inliers) / param.layerScale) + "px identified.");
                            Utils.log("Estimated transformation model: " + model);
                            models.set(ti, new Triple<Integer, Integer, AbstractModel<?>>(sliceA, sliceB, model));
                        } else {
                            Utils.log(layerNameB + " -> " + layerNameA + ": no correspondences found.");
                            return;
                        }
                    }
                };
                threads.add(thread);
                thread.start();
            }
            try {
                for (final Thread thread : threads) {
                    thread.join();
                }
            } catch (final InterruptedException e) {
                Utils.log("Establishing feature correspondences interrupted.");
                for (final Thread thread : threads) {
                    thread.interrupt();
                }
                try {
                    for (final Thread thread : threads) {
                        thread.join();
                    }
                } catch (final InterruptedException ignored) {
                    /* interrupted again while waiting for the workers; give up waiting */
                }
                /* restore the interrupt status so the caller can observe it */
                Thread.currentThread().interrupt();
                return;
            }
            threads.clear();
            /* collect successfully matches pairs and break the search on gaps */
            for (int t = 0; t < models.size(); ++t) {
                final Triple<Integer, Integer, AbstractModel<?>> pair = models.get(t);
                if (pair == null) {
                    if (++numFailures > param.maxNumFailures) {
                        break J;
                    }
                } else {
                    numFailures = 0;
                    pairs.add(pair);
                }
            }
        }
    }
}
Aggregations