Use of `qupath.lib.images.servers.TileRequest` in the QuPath project (qupath/qupath).
Class `PixelClassificationMeasurementManager`, method `calculateMeasurements`.
/**
 * Calculate measurements for a specified ROI if possible.
 * <p>
 * The classifier output is obtained for every tile overlapping the ROI (or for all tiles, if the
 * ROI is the root ROI). For each tile, a binary mask of the ROI is rendered and used to count
 * classified pixels. Counts are accumulated across all tiles and then passed to
 * {@code updateMeasurements}.
 *
 * @param roi the ROI for which measurements should be calculated
 * @param cachedOnly if true, only use tiles that are already cached, and abort (return null)
 *                   if any required tile is not available
 * @return the result of {@code updateMeasurements}, or null if the output type is
 *         {@code FEATURE}, no tiles are required for the ROI, or a required tile could not be
 *         obtained
 */
synchronized MeasurementList calculateMeasurements(final ROI roi, final boolean cachedOnly) {
	Map<Integer, PathClass> classificationLabels = classifierServer.getMetadata().getClassificationLabels();
	ImageServer<BufferedImage> server = classifierServer;

	// Check we have a suitable output type
	ImageServerMetadata.ChannelType type = classifierServer.getMetadata().getChannelType();
	if (type == ImageServerMetadata.ChannelType.FEATURE)
		return null;

	// Points are rasterized directly, so only lines & areas need a Java2D shape
	Shape shape = roi.isPoint() ? null : RoiTools.getShape(roi);

	// Get the regions we need
	Collection<TileRequest> requests;
	if (roi == rootROI) {
		// For the root, we want all tile requests
		requests = server.getTileRequestManager().getAllTileRequests();
	} else if (!roi.isEmpty()) {
		var regionRequest = RegionRequest.createInstance(server.getPath(), requestedDownsample, roi);
		requests = server.getTileRequestManager().getTileRequests(regionRequest);
	} else {
		requests = Collections.emptyList();
	}
	if (requests.isEmpty()) {
		logger.debug("Request empty for {}", roi);
		return null;
	}

	// Try to get all tiles first - if any is missing, return quickly (can't calculate measurement)
	Map<TileRequest, BufferedImage> localCache = new HashMap<>();
	for (TileRequest request : requests) {
		BufferedImage tile = null;
		try {
			tile = cachedOnly ? classifierServer.getCachedTile(request) : classifierServer.readBufferedImage(request.getRegionRequest());
		} catch (IOException e) {
			logger.error("Error requesting tile {}", request, e);
		}
		if (tile == null)
			return null;
		localCache.put(request, tile);
	}

	// Accumulate per-class pixel counts across all tiles
	long[] counts = null;
	BasicStroke stroke = null;
	BufferedImage imgMask = imgTileMask.get();
	for (Map.Entry<TileRequest, BufferedImage> entry : localCache.entrySet()) {
		TileRequest region = entry.getKey();
		BufferedImage tile = entry.getValue();
		int w = tile.getWidth();
		int h = tile.getHeight();

		// Reuse the mask image where possible; it may be larger than the current tile,
		// in which case only its top-left w x h region is relevant
		if (imgMask == null || imgMask.getWidth() < w || imgMask.getHeight() < h || imgMask.getType() != BufferedImage.TYPE_BYTE_GRAY) {
			imgMask = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY);
			imgTileMask.set(imgMask);
		}

		// Render a binary mask of the ROI for the current tile.
		// Because the mask image is reused, it must be cleared for every ROI type (including points),
		// otherwise pixels set for an earlier tile (or an earlier call) would be counted again.
		Graphics2D g2d = imgMask.createGraphics();
		try {
			g2d.setColor(Color.BLACK);
			g2d.fillRect(0, 0, w, h);
			if (roi.isLine() || roi.isArea()) {
				// Draw the shape in full-resolution image coordinates, mapped into tile pixel space
				g2d.setColor(Color.WHITE);
				g2d.scale(1.0 / region.getDownsample(), 1.0 / region.getDownsample());
				g2d.translate(-region.getTileX() * region.getDownsample(), -region.getTileY() * region.getDownsample());
				if (roi.isLine()) {
					// Stroke width in image coordinates that maps to ~1 pixel in the tile
					float fDownsample = (float) region.getDownsample();
					if (stroke == null || stroke.getLineWidth() != fDownsample)
						stroke = new BasicStroke(fDownsample);
					g2d.setStroke(stroke);
					g2d.draw(shape);
				} else {
					g2d.fill(shape);
				}
			}
		} finally {
			g2d.dispose();
		}
		if (roi.isPoint()) {
			for (var p : roi.getAllPoints()) {
				// Floor (not truncate) so points just left/above the tile are not mapped onto pixel 0
				int x = (int) Math.floor((p.getX() - region.getImageX()) / region.getDownsample());
				int y = (int) Math.floor((p.getY() - region.getImageY()) / region.getDownsample());
				// Only mark pixels inside the current tile (the mask may be larger)
				if (x >= 0 && y >= 0 && x < w && y < h)
					imgMask.getRaster().setSample(x, y, 0, 255);
			}
		}

		int nChannels = tile.getSampleModel().getNumBands();
		try {
			switch (type) {
			case CLASSIFICATION:
				// Calculate histogram to get labelled image counts
				counts = BufferedImageTools.computeUnsignedIntHistogram(tile.getRaster(), counts, imgMask.getRaster());
				break;
			case PROBABILITY:
				// Take classification from the channel with the highest value
				if (nChannels > 1) {
					counts = BufferedImageTools.computeArgMaxHistogram(tile.getRaster(), counts, imgMask.getRaster());
					break;
				}
				// For one channel, fall through & treat as multiclass
			case MULTICLASS_PROBABILITY:
				// For multiclass, count pixels above the threshold in each channel
				if (counts == null)
					counts = new long[nChannels];
				double threshold = getProbabilityThreshold(tile.getRaster());
				for (int c = 0; c < nChannels; c++)
					counts[c] += BufferedImageTools.computeAboveThresholdCounts(tile.getRaster(), c, threshold, imgMask.getRaster());
				// Must not fall through: the default case returns, which would skip the remaining tiles
				break;
			case DEFAULT:
			case FEATURE:
			default:
				// TODO: Consider handling other OutputTypes?
				return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
			}
		} catch (Exception e) {
			logger.error("Error calculating classification areas", e);
			if (nChannels > 1 && type == ChannelType.CLASSIFICATION)
				logger.error("There are {} channels - are you sure this is really a classification image?", nChannels);
		}
	}
	return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
}
Aggregations