Search in sources :

Example 41 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

the class ClassifierBuilderPane method updateRetainedObjectsMap.

/**
 * Refresh the entry of the retained objects map that belongs to the current image.
 * <p>
 * Nothing happens when no object hierarchy is available. Otherwise the classification map
 * for the hierarchy is rebuilt, stored under the key of the current image data, and the
 * retained objects label is updated.
 */
private void updateRetainedObjectsMap() {
    final PathObjectHierarchy hierarchy = getHierarchy();
    if (hierarchy == null) {
        return;
    }
    final Map<PathClass, List<PathObject>> classificationMap = PathClassificationLabellingHelper.getClassificationMap(
            hierarchy,
            paramsUpdate.getBooleanParameterValue("trainFromPoints"));
    // The return value (if any) of the count is not used here
    PathClassificationLabellingHelper.countObjectsInMap(classificationMap);
    // Replace whatever was previously retained for this image
    retainedObjectsMap.put(getMapKey(getImageData()), classificationMap);
    updateRetainedObjectsLabel();
}
Also used : PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) PathClass(qupath.lib.objects.classes.PathClass) ParameterList(qupath.lib.plugins.parameters.ParameterList) List(java.util.List) ArrayList(java.util.ArrayList)

Example 42 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

the class RetainedTrainingObjects method readExternal.

/**
 * Restore the retained objects from a stream.
 * <p>
 * The stream is expected to contain a version object followed by the serialized map.
 * Every deserialized {@code PathClass} key is swapped for the instance returned by
 * {@code PathClassFactory} for the same name and color before the map is stored.
 *
 * @param in the stream to read from
 * @throws IOException if the version object is not the supported one, or reading fails
 * @throws ClassNotFoundException if a serialized class cannot be resolved
 */
@SuppressWarnings("unchecked")
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    // The first object identifies the format version
    final Object versionTag = in.readObject();
    if (!"Training objects v0.1".equals(versionTag)) {
        throw new IOException("Version not supported!");
    }
    final Map<String, Map<PathClass, List<PathObject>>> deserialized = (Map<String, Map<PathClass, List<PathObject>>>) in.readObject();
    // Re-key every per-image map so that the 'current' PathClasses are used rather than the deserialized ones
    for (Entry<String, Map<PathClass, List<PathObject>>> imageEntry : deserialized.entrySet()) {
        final Map<PathClass, List<PathObject>> perClass = imageEntry.getValue();
        final Map<PathClass, List<PathObject>> remapped = new HashMap<>();
        for (Entry<PathClass, List<PathObject>> classEntry : perClass.entrySet()) {
            final PathClass stale = classEntry.getKey();
            final PathClass current = PathClassFactory.getPathClass(stale.getName(), stale.getColor());
            remapped.put(current, classEntry.getValue());
        }
        retainedObjectsMap.put(imageEntry.getKey(), remapped);
    }
}
Also used : PathClass(qupath.lib.objects.classes.PathClass) PathObject(qupath.lib.objects.PathObject) HashMap(java.util.HashMap) PathObject(qupath.lib.objects.PathObject) ArrayList(java.util.ArrayList) List(java.util.List) IOException(java.io.IOException) HashMap(java.util.HashMap) TreeMap(java.util.TreeMap) Map(java.util.Map)

Example 43 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

the class PixelClassificationMeasurementManager method updateMeasurements.

/**
 * Build a measurement list from per-label pixel counts.
 * <p>
 * For each classification that is neither null nor ignored, an area measurement is added
 * (plus a percentage when more than one such classification exists). Two totals follow:
 * the annotated area (all counted pixels) and the quantified area (only pixels belonging
 * to non-ignored classifications).
 * <p>
 * While {@code measurementNames} is still null, the names of the measurements are recorded
 * in it as well.
 *
 * @param classificationLabels map from an index into {@code counts} to the classification for that index
 * @param counts pixel counts per label index; may be null, in which case no values are added
 * @param pixelArea area of a single pixel; when NaN, no area values are added
 * @param pixelAreaUnits text appended to the names of area measurements
 * @return the closed measurement list
 */
private synchronized MeasurementList updateMeasurements(Map<Integer, PathClass> classificationLabels, long[] counts, double pixelArea, String pixelAreaUnits) {
    long total = counts == null ? 0L : GeneralTools.sum(counts);
    Collection<PathClass> pathClasses = new LinkedHashSet<>(classificationLabels.values());
    boolean addNames = measurementNames == null;
    List<String> tempList = null;
    int nMeasurements = pathClasses.size() * 2;
    if (!isMulticlass)
        nMeasurements += 2;
    if (addNames) {
        // Names are appended to tempList below; measurementNames is a read-only view of it
        tempList = new ArrayList<>();
        measurementNames = Collections.unmodifiableList(tempList);
    } else
        nMeasurements = measurementNames.size();
    MeasurementList measurementList = MeasurementListFactory.createMeasurementList(nMeasurements, MeasurementListType.DOUBLE);
    Set<PathClass> ignored = pathClasses.stream().filter(p -> p == null || PathClassTools.isIgnoredClass(p)).collect(Collectors.toSet());
    // Calculate totals for all non-ignored classes
    Map<PathClass, Long> pathClassTotals = new LinkedHashMap<>();
    long totalWithoutIgnored = 0L;
    if (counts != null) {
        for (var entry : classificationLabels.entrySet()) {
            PathClass pathClass = entry.getValue();
            // Skip background channels
            if (pathClass == null || ignored.contains(pathClass))
                continue;
            int c = entry.getKey();
            // Labels beyond the end of the counts array contribute nothing
            long temp = c >= counts.length ? 0L : counts[c];
            totalWithoutIgnored += temp;
            pathClassTotals.put(pathClass, pathClassTotals.getOrDefault(pathClass, 0L) + temp);
        }
    } else {
        // No counts available: still register the classes so that their names can be recorded
        for (var pathClass : pathClasses) {
            if (pathClass != null && !ignored.contains(pathClass))
                pathClassTotals.put(pathClass, 0L);
        }
    }
    // Add measurements for classes
    for (var entry : pathClassTotals.entrySet()) {
        var pathClass = entry.getKey();
        String name = pathClass.toString();
        String namePercentage = name + " %";
        String nameArea = name + " area " + pixelAreaUnits;
        if (tempList != null) {
            // A percentage is only meaningful when there is more than one class to compare
            if (pathClassTotals.size() > 1)
                tempList.add(namePercentage);
            tempList.add(nameArea);
        }
        if (counts != null) {
            long count = entry.getValue();
            if (pathClassTotals.size() > 1)
                measurementList.putMeasurement(namePercentage, (double) count / totalWithoutIgnored * 100.0);
            if (!Double.isNaN(pixelArea)) {
                measurementList.putMeasurement(nameArea, count * pixelArea);
            }
        }
    }
    // Add total area (useful as a check)
    String nameArea = "Total annotated area " + pixelAreaUnits;
    String nameAreaWithoutIgnored = "Total quantified area " + pixelAreaUnits;
    if (counts != null && !Double.isNaN(pixelArea)) {
        if (tempList != null) {
            tempList.add(nameArea);
            tempList.add(nameAreaWithoutIgnored);
        }
        // Annotated area covers every counted pixel; quantified area excludes ignored classes
        measurementList.putMeasurement(nameArea, total * pixelArea);
        measurementList.putMeasurement(nameAreaWithoutIgnored, totalWithoutIgnored * pixelArea);
    }
    measurementList.close();
    return measurementList;
}
Also used : LinkedHashSet(java.util.LinkedHashSet) Color(java.awt.Color) ImageServer(qupath.lib.images.servers.ImageServer) PathClassTools(qupath.lib.objects.classes.PathClassTools) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) MeasurementList(qupath.lib.measurements.MeasurementList) LinkedHashMap(java.util.LinkedHashMap) ROIs(qupath.lib.roi.ROIs) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) Graphics2D(java.awt.Graphics2D) Map(java.util.Map) TileRequest(qupath.lib.images.servers.TileRequest) BufferedImageTools(qupath.lib.awt.common.BufferedImageTools) MeasurementListFactory(qupath.lib.measurements.MeasurementListFactory) LinkedHashSet(java.util.LinkedHashSet) WeakHashMap(java.util.WeakHashMap) Shape(java.awt.Shape) MeasurementListType(qupath.lib.measurements.MeasurementList.MeasurementListType) RoiTools(qupath.lib.roi.RoiTools) Logger(org.slf4j.Logger) BufferedImage(java.awt.image.BufferedImage) GeneralTools(qupath.lib.common.GeneralTools) RegionRequest(qupath.lib.regions.RegionRequest) Collection(java.util.Collection) PathClass(qupath.lib.objects.classes.PathClass) Set(java.util.Set) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) PathObject(qupath.lib.objects.PathObject) ROI(qupath.lib.roi.interfaces.ROI) List(java.util.List) PixelCalibration(qupath.lib.images.servers.PixelCalibration) ImagePlane(qupath.lib.regions.ImagePlane) BasicStroke(java.awt.BasicStroke) WritableRaster(java.awt.image.WritableRaster) Collections(java.util.Collections) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) DataBuffer(java.awt.image.DataBuffer) MeasurementList(qupath.lib.measurements.MeasurementList) LinkedHashMap(java.util.LinkedHashMap) PathClass(qupath.lib.objects.classes.PathClass)

Example 44 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

the class PixelClassificationMeasurementManager method calculateMeasurements.

/**
 * Calculate measurements for a specified ROI if possible.
 * <p>
 * The tiles of the classifier server that overlap the ROI are collected, a binary mask of the
 * ROI is rendered for each tile, and the masked pixels are counted per class across all tiles.
 * The counts are then passed to {@code updateMeasurements}.
 *
 * @param roi the region to measure
 * @param cachedOnly abort the mission if required tiles are not cached
 * @return the measurements, or null if the channel type is FEATURE, no tiles overlap the ROI,
 *         or a required tile could not be obtained
 */
synchronized MeasurementList calculateMeasurements(final ROI roi, final boolean cachedOnly) {
    Map<Integer, PathClass> classificationLabels = classifierServer.getMetadata().getClassificationLabels();
    long[] counts = null;
    ImageServer<BufferedImage> server = classifierServer;
    // Check we have a suitable output type
    ImageServerMetadata.ChannelType type = classifierServer.getMetadata().getChannelType();
    if (type == ImageServerMetadata.ChannelType.FEATURE)
        return null;
    // A shape is only needed for drawing lines & areas into the mask
    Shape shape = null;
    if (!roi.isPoint())
        shape = RoiTools.getShape(roi);
    // Get the regions we need
    Collection<TileRequest> requests;
    // For the root, we want all tile requests
    if (roi == rootROI) {
        requests = server.getTileRequestManager().getAllTileRequests();
    } else if (!roi.isEmpty()) {
        var regionRequest = RegionRequest.createInstance(server.getPath(), requestedDownsample, roi);
        requests = server.getTileRequestManager().getTileRequests(regionRequest);
    } else
        requests = Collections.emptyList();
    if (requests.isEmpty()) {
        logger.debug("Request empty for {}", roi);
        return null;
    }
    // Try to get all cached tiles - if this fails, return quickly (can't calculate measurement)
    Map<TileRequest, BufferedImage> localCache = new HashMap<>();
    for (TileRequest request : requests) {
        BufferedImage tile = null;
        try {
            tile = cachedOnly ? classifierServer.getCachedTile(request) : classifierServer.readBufferedImage(request.getRegionRequest());
        } catch (IOException e) {
            logger.error("Error requesting tile " + request, e);
        }
        if (tile == null)
            return null;
        localCache.put(request, tile);
    }
    // Calculate stained proportions
    BasicStroke stroke = null;
    BufferedImage imgMask = imgTileMask.get();
    for (Map.Entry<TileRequest, BufferedImage> entry : localCache.entrySet()) {
        TileRequest region = entry.getKey();
        BufferedImage tile = entry.getValue();
        // Create a binary mask corresponding to the current tile (reusing a previous one if it is big enough)
        if (imgMask == null || imgMask.getWidth() < tile.getWidth() || imgMask.getHeight() < tile.getHeight() || imgMask.getType() != BufferedImage.TYPE_BYTE_GRAY) {
            imgMask = new BufferedImage(tile.getWidth(), tile.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
            imgTileMask.set(imgMask);
        }
        boolean lineOrArea = roi.isLine() || roi.isArea();
        // Always clear the part of the mask used for this tile first; the mask is reused between
        // tiles (and calls), so anything left over from earlier would otherwise be counted again
        Graphics2D g2d = imgMask.createGraphics();
        g2d.setColor(Color.BLACK);
        g2d.fillRect(0, 0, tile.getWidth(), tile.getHeight());
        if (lineOrArea) {
            // Draw the shape in tile coordinates
            g2d.setColor(Color.WHITE);
            g2d.scale(1.0 / region.getDownsample(), 1.0 / region.getDownsample());
            g2d.translate(-region.getTileX() * region.getDownsample(), -region.getTileY() * region.getDownsample());
            if (roi.isLine()) {
                // Use a stroke that is one pixel wide at the tile resolution
                float fDownsample = (float) region.getDownsample();
                if (stroke == null || stroke.getLineWidth() != fDownsample)
                    stroke = new BasicStroke(fDownsample);
                g2d.setStroke(stroke);
                g2d.draw(shape);
            } else
                g2d.fill(shape);
        }
        g2d.dispose();
        if (!lineOrArea && roi.isPoint()) {
            // Mark the single mask pixel that contains each point
            for (var p : roi.getAllPoints()) {
                int x = (int) ((p.getX() - region.getImageX()) / region.getDownsample());
                int y = (int) ((p.getY() - region.getImageY()) / region.getDownsample());
                if (x >= 0 && y >= 0 && x < imgMask.getWidth() && y < imgMask.getHeight())
                    imgMask.getRaster().setSample(x, y, 0, 255);
            }
        }
        int nChannels = tile.getSampleModel().getNumBands();
        try {
            switch(type) {
                case CLASSIFICATION:
                    // Calculate histogram to get labelled image counts
                    counts = BufferedImageTools.computeUnsignedIntHistogram(tile.getRaster(), counts, imgMask.getRaster());
                    break;
                case PROBABILITY:
                    // Take classification from the channel with the highest value
                    if (nChannels > 1) {
                        counts = BufferedImageTools.computeArgMaxHistogram(tile.getRaster(), counts, imgMask.getRaster());
                        break;
                    }
                // For one channel, fall through & treat as multiclass
                case MULTICLASS_PROBABILITY:
                    // For multiclass, count
                    if (counts == null)
                        counts = new long[nChannels];
                    double threshold = getProbabilityThreshold(tile.getRaster());
                    for (int c = 0; c < nChannels; c++)
                        counts[c] += BufferedImageTools.computeAboveThresholdCounts(tile.getRaster(), c, threshold, imgMask.getRaster());
                    // Continue with the next tile, so that counts accumulate over all tiles
                    break;
                case DEFAULT:
                case FEATURE:
                default:
                    // TODO: Consider handling other OutputTypes?
                    return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
            }
        } catch (Exception e) {
            logger.error("Error calculating classification areas", e);
            if (nChannels > 1 && type == ChannelType.CLASSIFICATION)
                logger.error("There are {} channels - are you sure this is really a classification image?", nChannels);
        }
    }
    return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
}
Also used : BasicStroke(java.awt.BasicStroke) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) Shape(java.awt.Shape) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) WeakHashMap(java.util.WeakHashMap) TileRequest(qupath.lib.images.servers.TileRequest) IOException(java.io.IOException) BufferedImage(java.awt.image.BufferedImage) IOException(java.io.IOException) Graphics2D(java.awt.Graphics2D) PathClass(qupath.lib.objects.classes.PathClass) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) WeakHashMap(java.util.WeakHashMap)

Example 45 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

the class PixelClassifierTools method createObjectsFromPixelClassifier.

/**
 * Trace objects from the output of an {@link ImageServer} that provides classification or probability pixels.
 *
 * @param server image server supplying the pixels to trace
 * @param labels classification labels keyed by channel/label index; if null, labels are parsed from the
 *               server metadata. Passing a map allows specific classifications to be left out.
 * @param roi region in which to create objects (optional; if null, the whole image is used).
 *            A non-null ROI that is not an area results in an empty collection.
 * @param creator function turning a ROI into an object (e.g. annotation or detection)
 * @param minArea minimum area for an object fragment to keep, in calibrated units based on the pixel calibration
 * @param minHoleArea minimum area for a hole to fill, in calibrated units based on the pixel calibration
 * @param doSplit if true, split connected regions into separate objects
 * @return the created objects, sorted with the default object comparator
 * @throws IOException
 * @throws IllegalArgumentException if no classification labels are available
 */
public static Collection<PathObject> createObjectsFromPixelClassifier(ImageServer<BufferedImage> server, Map<Integer, PathClass> labels, ROI roi, Function<ROI, ? extends PathObject> creator, double minArea, double minHoleArea, boolean doSplit) throws IOException {
    // Fall back to the labels of the server when none were supplied
    final Map<Integer, PathClass> labelMap = labels != null
            ? labels
            : parseClassificationLabels(server.getMetadata().getClassificationLabels(), false);
    if (labelMap == null || labelMap.isEmpty()) {
        throw new IllegalArgumentException("Cannot create objects for server - no classification labels are available!");
    }

    // One threshold per labelled channel
    final ChannelThreshold[] thresholds = labelMap.keySet().stream()
            .map(channel -> ChannelThreshold.create(channel))
            .toArray(ChannelThreshold[]::new);

    // Work out what to trace: either the ROI (which must be an area) or the whole image.
    // A list is used because the whole image may span several z-slices or timepoints.
    final Geometry clipArea;
    final List<RegionRequest> regionRequests;
    if (roi == null) {
        clipArea = null;
        regionRequests = RegionRequest.createAllRequests(server, server.getDownsampleForResolution(0));
    } else {
        if (!roi.isArea()) {
            logger.warn("Cannot create objects for non-area ROIs");
            return Collections.emptyList();
        }
        clipArea = roi.getGeometry();
        regionRequests = Collections.singletonList(
                RegionRequest.createInstance(server.getPath(), server.getDownsampleForResolution(0), roi));
    }

    // Convert the calibrated area limits into pixel units
    final double pixelWidth = server.getPixelCalibration().getPixelWidth().doubleValue();
    final double pixelHeight = server.getPixelCalibration().getPixelHeight().doubleValue();
    final double pixelArea = pixelWidth * pixelHeight;
    final double minAreaPixels = minArea / pixelArea;
    final double minHoleAreaPixels = minHoleArea / pixelArea;

    // Trace each request in turn (usually just one, unless there is a z-stack or time series)
    final var pathObjects = new ArrayList<PathObject>();
    for (RegionRequest request : regionRequests) {
        final Map<Integer, Geometry> traced = ContourTracing.traceGeometries(server, request, clipArea, thresholds);
        final var created = traced.entrySet()
                .parallelStream()
                .flatMap(e -> geometryToObjects(
                        e.getValue(),
                        creator,
                        labelMap.get(e.getKey()),
                        minAreaPixels,
                        minHoleAreaPixels,
                        doSplit,
                        request.getPlane()).stream())
                .collect(Collectors.toList());
        pathObjects.addAll(created);
    }
    pathObjects.sort(DefaultPathObjectComparator.getInstance());
    return pathObjects;
}
Also used : ImageServer(qupath.lib.images.servers.ImageServer) Arrays(java.util.Arrays) PathClassTools(qupath.lib.objects.classes.PathClassTools) LoggerFactory(org.slf4j.LoggerFactory) PathClassFactory(qupath.lib.objects.classes.PathClassFactory) PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) ChannelThreshold(qupath.lib.analysis.images.ContourTracing.ChannelThreshold) Function(java.util.function.Function) ArrayList(java.util.ArrayList) ClassifierFunction(qupath.opencv.ml.pixel.PixelClassifiers.ClassifierFunction) HashSet(java.util.HashSet) LinkedHashMap(java.util.LinkedHashMap) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) Map(java.util.Map) Reclassifier(qupath.lib.objects.classes.Reclassifier) GeometryTools(qupath.lib.roi.GeometryTools) ImageData(qupath.lib.images.ImageData) Logger(org.slf4j.Logger) RegionRequest(qupath.lib.regions.RegionRequest) BufferedImage(java.awt.image.BufferedImage) PathObjects(qupath.lib.objects.PathObjects) Collection(java.util.Collection) PathClass(qupath.lib.objects.classes.PathClass) Set(java.util.Set) DefaultPathObjectComparator(qupath.lib.objects.DefaultPathObjectComparator) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) PathObjectTools(qupath.lib.objects.PathObjectTools) PathObject(qupath.lib.objects.PathObject) ROI(qupath.lib.roi.interfaces.ROI) List(java.util.List) PixelClassifier(qupath.lib.classifiers.pixel.PixelClassifier) PixelClassificationImageServer(qupath.lib.classifiers.pixel.PixelClassificationImageServer) ColorModel(java.awt.image.ColorModel) ContourTracing(qupath.lib.analysis.images.ContourTracing) ImagePlane(qupath.lib.regions.ImagePlane) Geometry(org.locationtech.jts.geom.Geometry) Comparator(java.util.Comparator) Collections(java.util.Collections) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) DataBuffer(java.awt.image.DataBuffer) Geometry(org.locationtech.jts.geom.Geometry) ArrayList(java.util.ArrayList) 
ChannelThreshold(qupath.lib.analysis.images.ContourTracing.ChannelThreshold) RegionRequest(qupath.lib.regions.RegionRequest)

Aggregations

PathClass (qupath.lib.objects.classes.PathClass)66 ArrayList (java.util.ArrayList)42 PathObject (qupath.lib.objects.PathObject)34 List (java.util.List)29 Map (java.util.Map)25 IOException (java.io.IOException)21 Logger (org.slf4j.Logger)20 LoggerFactory (org.slf4j.LoggerFactory)20 Collections (java.util.Collections)17 Collectors (java.util.stream.Collectors)17 BufferedImage (java.awt.image.BufferedImage)16 LinkedHashMap (java.util.LinkedHashMap)16 ROI (qupath.lib.roi.interfaces.ROI)16 HashMap (java.util.HashMap)15 ImageData (qupath.lib.images.ImageData)15 PathClassFactory (qupath.lib.objects.classes.PathClassFactory)15 PathObjectHierarchy (qupath.lib.objects.hierarchy.PathObjectHierarchy)15 ParameterList (qupath.lib.plugins.parameters.ParameterList)15 Collection (java.util.Collection)14 TreeMap (java.util.TreeMap)11