Use of qupath.lib.objects.classes.PathClass in the QuPath project: class ClassifierBuilderPane, method updateRetainedObjectsMap.
/**
 * Update the retained objects map using the data from the current image.
 * <p>
 * Stores the current image's classification map under the image's map key in
 * {@code retainedObjectsMap}, then refreshes the associated label.
 * Does nothing if no object hierarchy is available.
 */
private void updateRetainedObjectsMap() {
    PathObjectHierarchy hierarchy = getHierarchy();
    if (hierarchy == null)
        return;
    // Map each PathClass to the training objects currently assigned to it,
    // optionally including point annotations depending on the 'trainFromPoints' parameter
    Map<PathClass, List<PathObject>> mapCurrent = PathClassificationLabellingHelper.getClassificationMap(
            hierarchy, paramsUpdate.getBooleanParameterValue("trainFromPoints"));
    // Retain this image's training data so it can be reused for other images
    retainedObjectsMap.put(getMapKey(getImageData()), mapCurrent);
    updateRetainedObjectsLabel();
}
Use of qupath.lib.objects.classes.PathClass in the QuPath project: class RetainedTrainingObjects, method readExternal.
/**
 * Read retained training objects from the supplied input.
 * <p>
 * Deserialized {@code PathClass} instances are replaced with the 'current'
 * instances obtained from {@code PathClassFactory}, so that reference/equality
 * checks elsewhere behave as expected.
 *
 * @throws IOException if the serialized version string is not supported
 * @throws ClassNotFoundException if deserialization fails to resolve a class
 */
@SuppressWarnings("unchecked")
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    // Version check - refuse any format we don't recognize
    Object version = in.readObject();
    if (!("Training objects v0.1".equals(version)))
        throw new IOException("Version not supported!");
    Object data = in.readObject();
    Map<String, Map<PathClass, List<PathObject>>> mapRead = (Map<String, Map<PathClass, List<PathObject>>>) data;
    // Loop through and ensure that we are using the 'current' PathClasses - not the deserialized ones
    for (Entry<String, Map<PathClass, List<PathObject>>> readEntry : mapRead.entrySet()) {
        Map<PathClass, List<PathObject>> newMap = new HashMap<>();
        for (Entry<PathClass, List<PathObject>> entry : readEntry.getValue().entrySet()) {
            PathClass pathClass = entry.getKey();
            // Re-resolve by name & color through the factory to get the canonical instance
            newMap.put(PathClassFactory.getPathClass(pathClass.getName(), pathClass.getColor()), entry.getValue());
        }
        retainedObjectsMap.put(readEntry.getKey(), newMap);
    }
}
Use of qupath.lib.objects.classes.PathClass in the QuPath project: class PixelClassificationMeasurementManager, method updateMeasurements.
/**
 * Build a MeasurementList from per-label pixel counts.
 *
 * @param classificationLabels map from channel/label index to its PathClass (may contain nulls for background)
 * @param counts pixel counts indexed by label, or null if no counts are available
 * @param pixelArea calibrated area of a single pixel (may be NaN if uncalibrated)
 * @param pixelAreaUnits units string for calibrated areas
 * @return the populated (and closed) measurement list
 */
private synchronized MeasurementList updateMeasurements(Map<Integer, PathClass> classificationLabels, long[] counts, double pixelArea, String pixelAreaUnits) {
    long total = counts == null ? 0L : GeneralTools.sum(counts);
    // LinkedHashSet preserves label order while removing duplicate classes
    Collection<PathClass> pathClasses = new LinkedHashSet<>(classificationLabels.values());
    boolean addNames = measurementNames == null;
    List<String> tempList = null;
    int nMeasurements = pathClasses.size() * 2;
    if (!isMulticlass)
        nMeasurements += 2;
    if (addNames) {
        // measurementNames is an unmodifiable *view* of tempList, so names added
        // below become visible through it
        tempList = new ArrayList<>();
        measurementNames = Collections.unmodifiableList(tempList);
    } else
        nMeasurements = measurementNames.size();
    MeasurementList measurementList = MeasurementListFactory.createMeasurementList(nMeasurements, MeasurementListType.DOUBLE);
    // Classes that should not contribute to the quantified total
    Set<PathClass> ignored = pathClasses.stream().filter(p -> p == null || PathClassTools.isIgnoredClass(p)).collect(Collectors.toSet());
    // Calculate totals for all non-ignored classes
    Map<PathClass, Long> pathClassTotals = new LinkedHashMap<>();
    long totalWithoutIgnored = 0L;
    if (counts != null) {
        for (var entry : classificationLabels.entrySet()) {
            PathClass pathClass = entry.getValue();
            // Skip background channels
            if (pathClass == null || ignored.contains(pathClass))
                continue;
            int c = entry.getKey();
            long temp = c >= counts.length ? 0L : counts[c];
            totalWithoutIgnored += temp;
            // Multiple labels may map to the same class - accumulate
            pathClassTotals.put(pathClass, pathClassTotals.getOrDefault(pathClass, 0L) + temp);
        }
    } else {
        for (var pathClass : pathClasses) {
            if (pathClass != null && !ignored.contains(pathClass))
                pathClassTotals.put(pathClass, 0L);
        }
    }
    // Add measurements for classes
    for (var entry : pathClassTotals.entrySet()) {
        var pathClass = entry.getKey();
        String name = pathClass.toString();
        String namePercentage = name + " %";
        String nameArea = name + " area " + pixelAreaUnits;
        if (tempList != null) {
            // A percentage is only meaningful when there is more than one class
            if (pathClassTotals.size() > 1)
                tempList.add(namePercentage);
            tempList.add(nameArea);
        }
        if (counts != null) {
            long count = entry.getValue();
            if (pathClassTotals.size() > 1)
                measurementList.putMeasurement(namePercentage, (double) count / totalWithoutIgnored * 100.0);
            if (!Double.isNaN(pixelArea)) {
                measurementList.putMeasurement(nameArea, count * pixelArea);
            }
        }
    }
    // Add total area (useful as a check)
    String nameArea = "Total annotated area " + pixelAreaUnits;
    String nameAreaWithoutIgnored = "Total quantified area " + pixelAreaUnits;
    if (counts != null && !Double.isNaN(pixelArea)) {
        if (tempList != null) {
            tempList.add(nameArea);
            tempList.add(nameAreaWithoutIgnored);
        }
        // Fix: the annotated total includes ignored classes (full 'total'), while
        // the quantified total excludes them - these two values were swapped before
        measurementList.putMeasurement(nameArea, total * pixelArea);
        measurementList.putMeasurement(nameAreaWithoutIgnored, totalWithoutIgnored * pixelArea);
    }
    measurementList.close();
    return measurementList;
}
Use of qupath.lib.objects.classes.PathClass in the QuPath project: class PixelClassificationMeasurementManager, method calculateMeasurements.
/**
 * Calculate measurements for a specified ROI if possible.
 *
 * @param roi the ROI for which measurements should be calculated
 * @param cachedOnly abort the mission if required tiles are not cached
 * @return the measurement list, or null if measurements could not be calculated
 * (unsupported channel type, empty request, or a required tile was unavailable)
 */
synchronized MeasurementList calculateMeasurements(final ROI roi, final boolean cachedOnly) {
Map<Integer, PathClass> classificationLabels = classifierServer.getMetadata().getClassificationLabels();
long[] counts = null;
// imageData.getServer();
ImageServer<BufferedImage> server = classifierServer;
// Check we have a suitable output type
ImageServerMetadata.ChannelType type = classifierServer.getMetadata().getChannelType();
if (type == ImageServerMetadata.ChannelType.FEATURE)
return null;
// An AWT shape is needed for rasterizing line/area ROIs; point ROIs are handled separately
Shape shape = null;
if (!roi.isPoint())
shape = RoiTools.getShape(roi);
// Get the regions we need
Collection<TileRequest> requests;
// For the root, we want all tile requests
if (roi == rootROI) {
requests = server.getTileRequestManager().getAllTileRequests();
} else if (!roi.isEmpty()) {
var regionRequest = RegionRequest.createInstance(server.getPath(), requestedDownsample, roi);
requests = server.getTileRequestManager().getTileRequests(regionRequest);
} else
requests = Collections.emptyList();
if (requests.isEmpty()) {
logger.debug("Request empty for {}", roi);
return null;
}
// Try to get all cached tiles - if this fails, return quickly (can't calculate measurement)
Map<TileRequest, BufferedImage> localCache = new HashMap<>();
for (TileRequest request : requests) {
BufferedImage tile = null;
try {
tile = cachedOnly ? classifierServer.getCachedTile(request) : classifierServer.readBufferedImage(request.getRegionRequest());
} catch (IOException e) {
logger.error("Error requesting tile " + request, e);
}
// Abort if any tile is missing (e.g. not cached when cachedOnly is true)
if (tile == null)
return null;
localCache.put(request, tile);
}
// Calculate stained proportions
BasicStroke stroke = null;
byte[] mask = null;
// Reuse a mask image stored in imgTileMask to avoid repeated allocation
BufferedImage imgMask = imgTileMask.get();
for (Map.Entry<TileRequest, BufferedImage> entry : localCache.entrySet()) {
TileRequest region = entry.getKey();
BufferedImage tile = entry.getValue();
// Create a binary mask corresponding to the current tile
if (imgMask == null || imgMask.getWidth() < tile.getWidth() || imgMask.getHeight() < tile.getHeight() || imgMask.getType() != BufferedImage.TYPE_BYTE_GRAY) {
imgMask = new BufferedImage(tile.getWidth(), tile.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
imgTileMask.set(imgMask);
}
// Get the tile, which is needed for sub-pixel accuracy
if (roi.isLine() || roi.isArea()) {
Graphics2D g2d = imgMask.createGraphics();
// Clear the mask to black, then paint the ROI in white
g2d.setColor(Color.BLACK);
g2d.fillRect(0, 0, tile.getWidth(), tile.getHeight());
g2d.setColor(Color.WHITE);
// Transform from full-image coordinates into this tile's (downsampled) coordinates
g2d.scale(1.0 / region.getDownsample(), 1.0 / region.getDownsample());
g2d.translate(-region.getTileX() * region.getDownsample(), -region.getTileY() * region.getDownsample());
if (roi.isLine()) {
float fDownsample = (float) region.getDownsample();
// Stroke width of one downsampled pixel, so lines remain visible in the mask
if (stroke == null || stroke.getLineWidth() != fDownsample)
stroke = new BasicStroke((float) fDownsample);
g2d.setStroke(stroke);
g2d.draw(shape);
} else if (roi.isArea())
g2d.fill(shape);
g2d.dispose();
} else if (roi.isPoint()) {
// Rasterize each point into the mask individually
for (var p : roi.getAllPoints()) {
int x = (int) ((p.getX() - region.getImageX()) / region.getDownsample());
int y = (int) ((p.getY() - region.getImageY()) / region.getDownsample());
if (x >= 0 && y >= 0 && x < imgMask.getWidth() && y < imgMask.getHeight())
imgMask.getRaster().setSample(x, y, 0, 255);
}
}
int h = tile.getHeight();
int w = tile.getWidth();
// NOTE(review): 'mask' is allocated here but never read or written below - it
// appears to be leftover from an earlier implementation; confirm before removing.
if (mask == null || mask.length != h * w)
mask = new byte[w * h];
int nChannels = tile.getSampleModel().getNumBands();
try {
switch(type) {
case CLASSIFICATION:
// Calculate histogram to get labelled image counts
counts = BufferedImageTools.computeUnsignedIntHistogram(tile.getRaster(), counts, imgMask.getRaster());
break;
case PROBABILITY:
// Take classification from the channel with the highest value
if (nChannels > 1) {
counts = BufferedImageTools.computeArgMaxHistogram(tile.getRaster(), counts, imgMask.getRaster());
break;
}
// For one channel, fall through & treat as multiclass
case MULTICLASS_PROBABILITY:
// For multiclass, count
if (counts == null)
counts = new long[nChannels];
double threshold = getProbabilityThreshold(tile.getRaster());
for (int c = 0; c < nChannels; c++) counts[c] += BufferedImageTools.computeAboveThresholdCounts(tile.getRaster(), c, threshold, imgMask.getRaster());
// NOTE(review): there is no break here, so MULTICLASS_PROBABILITY falls
// through to the return below and measurements are computed after only the
// FIRST tile in localCache - confirm whether this early return is intentional.
case DEFAULT:
case FEATURE:
default:
// TODO: Consider handling other OutputTypes?
return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
}
} catch (Exception e) {
logger.error("Error calculating classification areas", e);
if (nChannels > 1 && type == ChannelType.CLASSIFICATION)
logger.error("There are {} channels - are you sure this is really a classification image?", nChannels);
}
}
return updateMeasurements(classificationLabels, counts, pixelArea, pixelAreaUnits);
}
Use of qupath.lib.objects.classes.PathClass in the QuPath project: class PixelClassifierTools, method createObjectsFromPixelClassifier.
/**
 * Create objects based upon an {@link ImageServer} that provides classification or probability output.
 *
 * @param server image server providing pixels from which objects should be created
 * @param labels classification labels; if null, these will be taken from ImageServer#getMetadata() and all non-ignored classifications will be used.
 *               Providing a map makes it possible to explicitly exclude some classifications.
 * @param roi region of interest in which objects should be created (optional; if null, the entire image is used)
 * @param creator function to create an object from a ROI (e.g. annotation or detection)
 * @param minArea minimum area for an object fragment to retain, in calibrated units based on the pixel calibration
 * @param minHoleArea minimum area for a hole to fill, in calibrated units based on the pixel calibration
 * @param doSplit if true, split connected regions into separate objects
 * @return the objects created within the ROI
 * @throws IOException
 */
public static Collection<PathObject> createObjectsFromPixelClassifier(ImageServer<BufferedImage> server, Map<Integer, PathClass> labels, ROI roi, Function<ROI, ? extends PathObject> creator, double minArea, double minHoleArea, boolean doSplit) throws IOException {
    // Fall back to the server metadata when no labels were supplied
    if (labels == null)
        labels = parseClassificationLabels(server.getMetadata().getClassificationLabels(), false);
    if (labels == null || labels.isEmpty())
        throw new IllegalArgumentException("Cannot create objects for server - no classification labels are available!");
    // One threshold per label index, used to drive the contour tracing
    ChannelThreshold[] thresholds = labels.keySet().stream().map(ChannelThreshold::create).toArray(ChannelThreshold[]::new);
    if (roi != null && !roi.isArea()) {
        logger.warn("Cannot create objects for non-area ROIs");
        return Collections.emptyList();
    }
    Geometry clipArea = roi == null ? null : roi.getGeometry();
    // Identify regions for the selected ROI, or the entire image; this is a list
    // because multiple z-slices or timepoints may need to be handled
    double downsample = server.getDownsampleForResolution(0);
    List<RegionRequest> regionRequests = roi == null
            ? RegionRequest.createAllRequests(server, downsample)
            : Collections.singletonList(RegionRequest.createInstance(server.getPath(), downsample, roi));
    // Convert calibrated area thresholds into pixel counts
    var cal = server.getPixelCalibration();
    double pixelArea = cal.getPixelWidth().doubleValue() * cal.getPixelHeight().doubleValue();
    double minAreaPixels = minArea / pixelArea;
    double minHoleAreaPixels = minHoleArea / pixelArea;
    // Collect the output objects across all region requests
    var pathObjects = new ArrayList<PathObject>();
    var labelMap = labels; // effectively-final copy for use inside the lambda
    for (RegionRequest regionRequest : regionRequests) {
        Map<Integer, Geometry> geometryMap = ContourTracing.traceGeometries(server, regionRequest, clipArea, thresholds);
        var plane = regionRequest.getPlane();
        List<PathObject> created = geometryMap.entrySet().parallelStream()
                .flatMap(e -> geometryToObjects(e.getValue(), creator, labelMap.get(e.getKey()), minAreaPixels, minHoleAreaPixels, doSplit, plane).stream())
                .collect(Collectors.toList());
        pathObjects.addAll(created);
    }
    pathObjects.sort(DefaultPathObjectComparator.getInstance());
    return pathObjects;
}
Aggregations