Search in sources:

Example 6 with TileRequest

Use of qupath.lib.images.servers.TileRequest in project qupath by qupath.

In class PixelClassificationMeasurementManager, method calculateMeasurements.

/**
 * Calculate measurements for a specified ROI if possible.
 * <p>
 * The classifier output tiles overlapping the ROI are retrieved (from the cache only, or read on
 * demand), a binary mask of the ROI is painted for each tile, and the masked pixels are counted
 * according to the classifier's {@link ImageServerMetadata.ChannelType}. The counts are then
 * converted into measurements by {@code updateMeasurements}.
 *
 * @param roi the region of interest to measure; if it is the root ROI, all tiles are used
 * @param cachedOnly abort the mission if required tiles are not cached
 * @return the measurement list built from the pixel counts, or {@code null} if measurements could
 *         not be calculated (feature output, no tiles overlapping the ROI, or a required tile was
 *         unavailable)
 */
synchronized MeasurementList calculateMeasurements(final ROI roi, final boolean cachedOnly) {
    Map<Integer, PathClass> classificationLabels = classifierServer.getMetadata().getClassificationLabels();
    long[] counts = null;
    ImageServer<BufferedImage> server = classifierServer;

    // Check we have a suitable output type - feature maps can't be turned into class areas
    ImageServerMetadata.ChannelType type = classifierServer.getMetadata().getChannelType();
    if (type == ImageServerMetadata.ChannelType.FEATURE)
        return null;

    // Points are handled pixel-by-pixel, so only lines & areas need a Java2D shape
    Shape shape = null;
    if (!roi.isPoint())
        shape = RoiTools.getShape(roi);

    // Get the regions we need
    Collection<TileRequest> requests;
    if (roi == rootROI) {
        // For the root, we want all tile requests
        requests = server.getTileRequestManager().getAllTileRequests();
    } else if (!roi.isEmpty()) {
        var regionRequest = RegionRequest.createInstance(server.getPath(), requestedDownsample, roi);
        requests = server.getTileRequestManager().getTileRequests(regionRequest);
    } else
        requests = Collections.emptyList();
    if (requests.isEmpty()) {
        logger.debug("Request empty for {}", roi);
        return null;
    }

    // Try to get all tiles up front - if any is missing, return quickly (can't calculate measurement)
    Map<TileRequest, BufferedImage> localCache = new HashMap<>();
    for (TileRequest request : requests) {
        BufferedImage tile = null;
        try {
            tile = cachedOnly ? classifierServer.getCachedTile(request) : classifierServer.readBufferedImage(request.getRegionRequest());
        } catch (IOException e) {
            logger.error("Error requesting tile {}", request, e);
        }
        if (tile == null)
            return null;
        localCache.put(request, tile);
    }

    // Calculate stained proportions
    BasicStroke stroke = null;
    BufferedImage imgMask = imgTileMask.get();
    for (Map.Entry<TileRequest, BufferedImage> entry : localCache.entrySet()) {
        TileRequest region = entry.getKey();
        BufferedImage tile = entry.getValue();
        int w = tile.getWidth();
        int h = tile.getHeight();

        // (Re)allocate the shared mask only when it is too small or of the wrong type
        if (imgMask == null || imgMask.getWidth() < w || imgMask.getHeight() < h || imgMask.getType() != BufferedImage.TYPE_BYTE_GRAY) {
            imgMask = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_GRAY);
            imgTileMask.set(imgMask);
        }

        // Create a binary mask corresponding to the current tile.
        // The mask is reused across tiles (and calls), so it must ALWAYS be cleared first -
        // otherwise foreground pixels from a previous tile leak into the counts (notably for points).
        boolean drawShape = roi.isLine() || roi.isArea();
        Graphics2D g2d = imgMask.createGraphics();
        try {
            g2d.setColor(Color.BLACK);
            g2d.fillRect(0, 0, w, h);
            if (drawShape) {
                double downsample = region.getDownsample();
                g2d.setColor(Color.WHITE);
                g2d.scale(1.0 / downsample, 1.0 / downsample);
                g2d.translate(-region.getTileX() * downsample, -region.getTileY() * downsample);
                if (roi.isLine()) {
                    // A stroke 'downsample' full-resolution pixels wide becomes 1 pixel wide in the mask
                    float fDownsample = (float) downsample;
                    if (stroke == null || stroke.getLineWidth() != fDownsample)
                        stroke = new BasicStroke(fDownsample);
                    g2d.setStroke(stroke);
                    g2d.draw(shape);
                } else
                    g2d.fill(shape);
            }
        } finally {
            g2d.dispose();
        }
        if (!drawShape && roi.isPoint()) {
            for (var p : roi.getAllPoints()) {
                // floor (not truncation toward zero) so points just left/above the tile aren't pulled onto pixel 0
                int x = (int) Math.floor((p.getX() - region.getImageX()) / region.getDownsample());
                int y = (int) Math.floor((p.getY() - region.getImageY()) / region.getDownsample());
                // Only mark pixels inside this tile; the rest of the shared mask is not cleared
                if (x >= 0 && y >= 0 && x < w && y < h)
                    imgMask.getRaster().setSample(x, y, 0, 255);
            }
        }

        int nChannels = tile.getSampleModel().getNumBands();
        try {
            switch (type) {
                case CLASSIFICATION:
                    // Calculate histogram to get labelled image counts
                    counts = BufferedImageTools.computeUnsignedIntHistogram(tile.getRaster(), counts, imgMask.getRaster());
                    break;
                case PROBABILITY:
                    // Take classification from the channel with the highest value
                    if (nChannels > 1) {
                        counts = BufferedImageTools.computeArgMaxHistogram(tile.getRaster(), counts, imgMask.getRaster());
                        break;
                    }
                    // For one channel, fall through & treat as multiclass
                case MULTICLASS_PROBABILITY:
                    // For multiclass, count pixels above the threshold independently per channel
                    if (counts == null)
                        counts = new long[nChannels];
                    double threshold = getProbabilityThreshold(tile.getRaster());
                    for (int c = 0; c < nChannels; c++)
                        counts[c] += BufferedImageTools.computeAboveThresholdCounts(tile.getRaster(), c, threshold, imgMask.getRaster());
                    // Must not fall through: the default branch returns, which would skip all remaining tiles
                    break;
                case DEFAULT:
                case FEATURE:
                default:
                    // TODO: Consider handling other OutputTypes?
                    return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
            }
        } catch (Exception e) {
            logger.error("Error calculating classification areas", e);
            if (nChannels > 1 && type == ChannelType.CLASSIFICATION)
                logger.error("There are {} channels - are you sure this is really a classification image?", nChannels);
        }
    }
    return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
}
Also used: BasicStroke(java.awt.BasicStroke) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) Shape(java.awt.Shape) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) WeakHashMap(java.util.WeakHashMap) TileRequest(qupath.lib.images.servers.TileRequest) IOException(java.io.IOException) BufferedImage(java.awt.image.BufferedImage) Graphics2D(java.awt.Graphics2D) PathClass(qupath.lib.objects.classes.PathClass) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) Map(java.util.Map)

Aggregations

TileRequest (qupath.lib.images.servers.TileRequest)6 BufferedImage (java.awt.image.BufferedImage)5 RegionRequest (qupath.lib.regions.RegionRequest)5 IOException (java.io.IOException)4 ArrayList (java.util.ArrayList)4 List (java.util.List)4 ImageServerMetadata (qupath.lib.images.servers.ImageServerMetadata)4 Collection (java.util.Collection)3 Collections (java.util.Collections)3 LinkedHashMap (java.util.LinkedHashMap)3 Map (java.util.Map)3 Collectors (java.util.stream.Collectors)3 ImageServer (qupath.lib.images.servers.ImageServer)3 PathObject (qupath.lib.objects.PathObject)3 Graphics2D (java.awt.Graphics2D)2 Shape (java.awt.Shape)2 Raster (java.awt.image.Raster)2 Path (java.nio.file.Path)2 ArrayDeque (java.util.ArrayDeque)2 Deque (java.util.Deque)2