Search in sources :

Example 1 with Reclassifier

use of qupath.lib.objects.classes.Reclassifier in project qupath by qupath.

From the class OpenCVMLClassifier, method classifyObjects.

/**
 * Classify a collection of objects with an OpenCV statistical model, working through them in chunks.
 * <p>
 * For each chunk, features are extracted into a matrix, the classifier is asked for predictions
 * (and optionally probabilities), and one {@link Reclassifier} is queued per object. The queued
 * classifications are only applied once every chunk has been processed, so an interrupted run
 * leaves all objects unchanged.
 *
 * @param featureExtractor extracts one row of features per object; if null, a warning is logged and nothing is classified
 * @param classifier the model used for prediction
 * @param pathClasses the available classes, looked up by predicted index (or by probability column for multiclass output)
 * @param imageData passed through to the feature extractor
 * @param pathObjects the objects to classify
 * @param resetExistingClass if true, the predicted class replaces any existing one; if false, the prediction is
 *                           merged with the object's current class using {@code PathClassTools.mergeClasses}
 * @param requestProbabilityEstimate if true, a probability matrix is passed to the classifier to be filled in
 * @return the number of objects in the chunks that were processed (this includes chunks whose prediction threw an
 *         exception that was logged), or 0 if there was no feature extractor or the thread was interrupted
 */
static <T> int classifyObjects(FeatureExtractor<T> featureExtractor, OpenCVStatModel classifier, List<PathClass> pathClasses, ImageData<T> imageData, Collection<? extends PathObject> pathObjects, boolean resetExistingClass, boolean requestProbabilityEstimate) {
    if (featureExtractor == null) {
        logger.warn("No feature extractor! Cannot classify {} objects", pathObjects.size());
        return 0;
    }
    int counter = 0;
    List<Reclassifier> reclassifiers = new ArrayList<>();
    // Try not to have more than ~10 million entries per list
    int subListSize = Math.max(1, Math.min(pathObjects.size(), (1024 * 1024 * 10 / featureExtractor.nFeatures())));
    Mat samples = new Mat();
    Mat results = new Mat();
    Mat probabilities = requestProbabilityEstimate ? new Mat() : null;
    try {
        // Work through the objects in chunks
        long startTime = System.currentTimeMillis();
        long lastTime = startTime;
        int nComplete = 0;
        for (var tempObjectList : Lists.partition(new ArrayList<>(pathObjects), subListSize)) {
            // Note: Thread.interrupted() also clears the thread's interrupt flag
            if (Thread.interrupted()) {
                logger.warn("Classification interrupted - will not be applied");
                // The finally block below closes the matrices on this early exit
                return 0;
            }
            samples.create(tempObjectList.size(), featureExtractor.nFeatures(), opencv_core.CV_32FC1);
            FloatBuffer buffer = samples.createBuffer();
            featureExtractor.extractFeatures(imageData, tempObjectList, buffer);
            // Possibly log time taken (at most about once per second)
            nComplete += tempObjectList.size();
            long intermediateTime = System.currentTimeMillis();
            if (intermediateTime - lastTime > 1000L) {
                logger.debug("Calculated features for {}/{} objects in {} ms ({} ms per object, {}% complete)", nComplete, pathObjects.size(), (intermediateTime - startTime), GeneralTools.formatNumber((intermediateTime - startTime) / (double) nComplete, 2), GeneralTools.formatNumber(nComplete * 100.0 / pathObjects.size(), 1));
                // Restart the throttle from now; resetting to startTime would log on every later chunk
                lastTime = intermediateTime;
            }
            boolean doMulticlass = classifier.supportsMulticlass();
            double threshold = 0.5;
            IntIndexer idxResults = null;
            FloatIndexer idxProbabilities = null;
            try {
                classifier.predict(samples, results, probabilities);
                idxResults = results.createIndexer();
                if (probabilities != null && !probabilities.empty())
                    idxProbabilities = probabilities.createIndexer();
                if (doMulticlass && idxProbabilities != null) {
                    // Use probabilities if we require multiclass outputs
                    long row = 0;
                    // Previously .cols()
                    int nCols = (int) idxProbabilities.size(2);
                    List<String> classifications = new ArrayList<>();
                    for (var pathObject : tempObjectList) {
                        classifications.clear();
                        // Collect the name of every class whose probability reaches the threshold
                        for (int col = 0; col < nCols; col++) {
                            double prob = idxProbabilities.get(row, col);
                            if (prob >= threshold) {
                                var pathClass = col >= pathClasses.size() ? null : pathClasses.get(col);
                                if (pathClass != null)
                                    classifications.add(pathClass.getName());
                            }
                        }
                        var pathClass = PathClassFactory.getPathClass(classifications);
                        if (PathClassTools.isIgnoredClass(pathClass)) {
                            pathClass = null;
                        }
                        if (!resetExistingClass) {
                            pathClass = PathClassTools.mergeClasses(pathObject.getPathClass(), pathClass);
                        }
                        reclassifiers.add(new Reclassifier(pathObject, pathClass, false));
                        row++;
                    }
                } else {
                    // Use results (indexed values) if we do not require multiclass outputs
                    long row = 0;
                    for (var pathObject : tempObjectList) {
                        int prediction = idxResults.get(row);
                        var pathClass = pathClasses.get(prediction);
                        double probability = idxProbabilities == null ? Double.NaN : idxProbabilities.get(row, prediction);
                        if (PathClassTools.isIgnoredClass(pathClass)) {
                            pathClass = null;
                            probability = Double.NaN;
                        }
                        if (!resetExistingClass) {
                            // The probability is discarded whenever the prediction is merged with an existing class
                            pathClass = PathClassTools.mergeClasses(pathObject.getPathClass(), pathClass);
                            probability = Double.NaN;
                        }
                        reclassifiers.add(new Reclassifier(pathObject, pathClass, true, probability));
                        row++;
                    }
                }
            } catch (Exception e) {
                logger.warn("Error with samples: {}", samples);
                logger.error(e.getLocalizedMessage(), e);
            } finally {
                // Release the indexers even if prediction or classification failed part-way through
                if (idxResults != null)
                    idxResults.release();
                if (idxProbabilities != null)
                    idxProbabilities.release();
            }
            counter += tempObjectList.size();
        }
        long predictTime = System.currentTimeMillis() - startTime;
        // predictTime is in milliseconds, so scale by 1e6 to report nanoseconds as the message states
        logger.info("Prediction time: {} ms for {} objects ({} ns per object)", predictTime, pathObjects.size(), GeneralTools.formatNumber((double) predictTime / pathObjects.size() * 1_000_000.0, 2));
    } finally {
        samples.close();
        results.close();
        if (probabilities != null)
            probabilities.close();
    }
    // Apply classifications now
    reclassifiers.forEach(p -> p.apply());
    return counter;
}
Also used : Mat(org.bytedeco.opencv.opencv_core.Mat) ArrayList(java.util.ArrayList) FloatBuffer(java.nio.FloatBuffer) FloatIndexer(org.bytedeco.javacpp.indexer.FloatIndexer) Reclassifier(qupath.lib.objects.classes.Reclassifier) IntIndexer(org.bytedeco.javacpp.indexer.IntIndexer)

Example 2 with Reclassifier

use of qupath.lib.objects.classes.Reclassifier in project qupath by qupath.

From the class AbstractPathROITool, method commitObjectToHierarchy.

/**
 * Finish handling an object once drawing is complete.
 * <p>
 * In selection mode the drawn ROI is used to find objects in the hierarchy; these are optionally
 * reclassified with the auto-set annotation class and then selected, and the drawn object itself is
 * removed from the hierarchy if it has a parent. Otherwise the drawn object is added to the hierarchy
 * (optionally clipped by its parent) and selected, unless its ROI turns out to be empty.
 * Finally, the ROI editor handles are refreshed and the tool may switch back to 'Move'.
 *
 * @param e the mouse event associated with completing the object
 * @param pathObject the object that was drawn; if null, the method returns immediately
 */
void commitObjectToHierarchy(MouseEvent e, PathObject pathObject) {
    if (pathObject == null) {
        return;
    }

    var viewer = getViewer();
    PathObjectHierarchy hierarchy = viewer.getHierarchy();
    var drawnROI = pathObject.getROI();

    if (PathPrefs.selectionModeProperty().get()) {
        // Selection mode: the drawn shape is a means of picking existing objects
        var autoClass = PathPrefs.autoSetAnnotationClassProperty().get();
        var candidates = hierarchy.getObjectsForROI(null, drawnROI);

        if (autoClass != null && !candidates.isEmpty()) {
            boolean retainIntensityClass = !PathClassTools.isPositiveOrGradedIntensityClass(autoClass)
                    && !PathClassTools.isNegativeClass(autoClass);
            // Only objects whose classification actually changed are reported to listeners
            var changedObjects = candidates.stream()
                    .filter(candidate -> candidate.getPathClass() != autoClass)
                    .map(candidate -> new Reclassifier(candidate, autoClass, retainIntensityClass))
                    .filter(reclassifier -> reclassifier.apply())
                    .map(reclassifier -> reclassifier.getPathObject())
                    .collect(Collectors.toList());
            if (!changedObjects.isEmpty()) {
                hierarchy.fireObjectClassificationsChangedEvent(this, changedObjects);
            }
        }

        // The drawn object was only temporary, so take it out of the hierarchy if it was added
        if (pathObject.getParent() != null) {
            hierarchy.removeObject(pathObject, true);
        }

        if (candidates.isEmpty()) {
            viewer.setSelectedObject(null);
        } else if (e.isShiftDown()) {
            // Shift extends the current selection rather than replacing it
            var selectionModel = hierarchy.getSelectionModel();
            selectionModel.deselectObject(pathObject);
            hierarchy.getSelectionModel().selectObjects(candidates);
        } else {
            hierarchy.getSelectionModel().setSelectedObjects(candidates, null);
        }
    } else {
        if (requestParentClipping(e)) {
            ROI clippedROI = refineROIByParent(pathObject.getROI());
            if (clippedROI.isEmpty()) {
                hierarchy.removeObject(pathObject, true);
                pathObject = null;
            } else {
                ((PathAnnotationObject) pathObject).setROI(clippedROI);
                hierarchy.addPathObjectBelowParent(getCurrentParent(), pathObject, true);
            }
        } else if (drawnROI.isEmpty()) {
            pathObject = null;
        } else {
            // Ensure object is within the hierarchy
            hierarchy.addPathObject(pathObject);
        }

        if (pathObject == null) {
            viewer.getHierarchy().getSelectionModel().clearSelection();
        } else {
            viewer.setSelectedObject(pathObject);
        }
    }

    var roiEditor = viewer.getROIEditor();
    roiEditor.ensureHandlesUpdated();
    roiEditor.resetActiveHandle();

    if (preferReturnToMove()) {
        var gui = QuPathGUI.getInstance();
        if (gui != null) {
            gui.setSelectedTool(PathTools.MOVE);
        }
    }
}
Also used : Logger(org.slf4j.Logger) Point2D(java.awt.geom.Point2D) PathClassTools(qupath.lib.objects.classes.PathClassTools) MouseEvent(javafx.scene.input.MouseEvent) PathObjects(qupath.lib.objects.PathObjects) LoggerFactory(org.slf4j.LoggerFactory) PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) PolylineROI(qupath.lib.roi.PolylineROI) Collectors(java.util.stream.Collectors) PathAnnotationObject(qupath.lib.objects.PathAnnotationObject) PathObject(qupath.lib.objects.PathObject) Cursor(javafx.scene.Cursor) ROI(qupath.lib.roi.interfaces.ROI) RoiEditor(qupath.lib.roi.RoiEditor) ImagePlane(qupath.lib.regions.ImagePlane) Reclassifier(qupath.lib.objects.classes.Reclassifier) PolygonROI(qupath.lib.roi.PolygonROI) Collections(java.util.Collections) PathPrefs(qupath.lib.gui.prefs.PathPrefs) QuPathGUI(qupath.lib.gui.QuPathGUI) PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) PathAnnotationObject(qupath.lib.objects.PathAnnotationObject) PolylineROI(qupath.lib.roi.PolylineROI) ROI(qupath.lib.roi.interfaces.ROI) PolygonROI(qupath.lib.roi.PolygonROI) Reclassifier(qupath.lib.objects.classes.Reclassifier)

Example 3 with Reclassifier

use of qupath.lib.objects.classes.Reclassifier in project qupath by qupath.

From the class PixelClassifierTools, method classifyObjectsByCentroid.

/**
 * Classify objects according to the classification found at each object's ROI centroid.
 * <p>
 * The classification labels are read from the server metadata. For every object, the centroid of its
 * ROI is truncated to integer pixel coordinates and the classification index at that location (and the
 * ROI's z and t) is mapped to a label; an index with no label gives a null classification. If anything
 * throws while handling an object, that object is also given a null classification.
 * All classifications are computed first and applied afterwards; both steps use parallel streams.
 *
 * @param classifierServer an {@link ImageServer} providing the classification labels in its metadata
 *                         and the classification at each location
 * @param pathObjects the objects to classify
 * @param preferNucleusROI passed to {@code PathObjectTools.getROI} when choosing the ROI for each object
 */
public static void classifyObjectsByCentroid(ImageServer<BufferedImage> classifierServer, Collection<PathObject> pathObjects, boolean preferNucleusROI) {
    var labels = classifierServer.getMetadata().getClassificationLabels();

    // Step 1: work out the new classification for every object without changing anything yet
    var pending = pathObjects.parallelStream()
            .map(pathObject -> {
                try {
                    var objectROI = PathObjectTools.getROI(pathObject, preferNucleusROI);
                    int centroidX = (int) objectROI.getCentroidX();
                    int centroidY = (int) objectROI.getCentroidY();
                    int labelIndex = getClassification(classifierServer, centroidX, centroidY, objectROI.getZ(), objectROI.getT());
                    var pathClass = labels.getOrDefault(labelIndex, null);
                    return new Reclassifier(pathObject, pathClass, false);
                } catch (Exception ex) {
                    // Any failure results in a null classification for this object
                    return new Reclassifier(pathObject, null, false);
                }
            })
            .collect(Collectors.toList());

    // Step 2: apply the classifications
    pending.parallelStream().forEach(reclassifier -> reclassifier.apply());
}
Also used : ImageServer(qupath.lib.images.servers.ImageServer) Arrays(java.util.Arrays) PathClassTools(qupath.lib.objects.classes.PathClassTools) LoggerFactory(org.slf4j.LoggerFactory) PathClassFactory(qupath.lib.objects.classes.PathClassFactory) PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) ChannelThreshold(qupath.lib.analysis.images.ContourTracing.ChannelThreshold) Function(java.util.function.Function) ArrayList(java.util.ArrayList) ClassifierFunction(qupath.opencv.ml.pixel.PixelClassifiers.ClassifierFunction) HashSet(java.util.HashSet) LinkedHashMap(java.util.LinkedHashMap) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) Map(java.util.Map) Reclassifier(qupath.lib.objects.classes.Reclassifier) GeometryTools(qupath.lib.roi.GeometryTools) ImageData(qupath.lib.images.ImageData) Logger(org.slf4j.Logger) RegionRequest(qupath.lib.regions.RegionRequest) BufferedImage(java.awt.image.BufferedImage) PathObjects(qupath.lib.objects.PathObjects) Collection(java.util.Collection) PathClass(qupath.lib.objects.classes.PathClass) Set(java.util.Set) DefaultPathObjectComparator(qupath.lib.objects.DefaultPathObjectComparator) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) PathObjectTools(qupath.lib.objects.PathObjectTools) PathObject(qupath.lib.objects.PathObject) ROI(qupath.lib.roi.interfaces.ROI) List(java.util.List) PixelClassifier(qupath.lib.classifiers.pixel.PixelClassifier) PixelClassificationImageServer(qupath.lib.classifiers.pixel.PixelClassificationImageServer) ColorModel(java.awt.image.ColorModel) ContourTracing(qupath.lib.analysis.images.ContourTracing) ImagePlane(qupath.lib.regions.ImagePlane) Geometry(org.locationtech.jts.geom.Geometry) Comparator(java.util.Comparator) Collections(java.util.Collections) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) DataBuffer(java.awt.image.DataBuffer) Reclassifier(qupath.lib.objects.classes.Reclassifier) IOException(java.io.IOException)

Aggregations

Reclassifier (qupath.lib.objects.classes.Reclassifier)3 ArrayList (java.util.ArrayList)2 Collections (java.util.Collections)2 Collectors (java.util.stream.Collectors)2 Logger (org.slf4j.Logger)2 LoggerFactory (org.slf4j.LoggerFactory)2 PathObject (qupath.lib.objects.PathObject)2 PathObjects (qupath.lib.objects.PathObjects)2 PathClassTools (qupath.lib.objects.classes.PathClassTools)2 PathObjectHierarchy (qupath.lib.objects.hierarchy.PathObjectHierarchy)2 ImagePlane (qupath.lib.regions.ImagePlane)2 ROI (qupath.lib.roi.interfaces.ROI)2 Point2D (java.awt.geom.Point2D)1 BufferedImage (java.awt.image.BufferedImage)1 ColorModel (java.awt.image.ColorModel)1 DataBuffer (java.awt.image.DataBuffer)1 IOException (java.io.IOException)1 FloatBuffer (java.nio.FloatBuffer)1 Arrays (java.util.Arrays)1 Collection (java.util.Collection)1