Search in sources:

Example 31 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

The class ClassifierBuilderPane, method crossValidateAcrossImages.

/**
 * Run leave-one-image-out cross-validation across all retained training data.
 * <p>
 * For each image key in the retained objects map, the 'live' classifier is trained using the objects
 * from every other image (subsampled according to the current split settings). It is then applied to the
 * objects of the held-out image. The percentage of held-out objects whose new classification equals
 * their training class is logged per image.
 * <p>
 * Note that validation calls {@code classifyPathObjects} on the held-out objects themselves, so their
 * classifications are overwritten. Because the live classifier is modified, a normal classification
 * update is always triggered afterwards, even if cross-validation stops early or throws.
 */
private void crossValidateAcrossImages() {
    // Try to put the current image data information into the tempMap, which stores training data separated by image path
    updateRetainedObjectsMap();
    Map<String, Map<PathClass, List<PathObject>>> tempMap = new LinkedHashMap<>(retainedObjectsMap.getMap());
    // These parameters are identical for every fold, so read them once
    Normalization normalization = (Normalization) paramsUpdate.getChoiceParameterValue("normalizationMethod");
    SplitType splitType = (SplitType) paramsUpdate.getChoiceParameterValue("splitType");
    double maxTrainingProportion = paramsUpdate.getIntParameterValue("maxTrainingPercent") / 100.;
    long seed = paramsUpdate.getIntParameterValue("randomSeed");
    try {
        for (Entry<String, Map<PathClass, List<PathObject>>> validationEntry : tempMap.entrySet()) {
            String key = validationEntry.getKey();
            Map<PathClass, List<PathObject>> validationMap = validationEntry.getValue();
            // Pool the training objects from all *other* images, grouped by class
            Map<PathClass, List<PathObject>> trainingMap = new LinkedHashMap<>();
            for (Entry<String, Map<PathClass, List<PathObject>>> entry : tempMap.entrySet()) {
                if (entry.getKey().equals(key))
                    continue;
                for (Entry<PathClass, List<PathObject>> entry2 : entry.getValue().entrySet()) {
                    // Always add into a fresh list, so the retained lists themselves are never modified
                    trainingMap.computeIfAbsent(entry2.getKey(), k -> new ArrayList<>()).addAll(entry2.getValue());
                }
            }
            // Perform subsampling
            trainingMap = PathClassificationLabellingHelper.resampleClassificationMap(trainingMap, splitType, maxTrainingProportion, seed);
            if (trainingMap.size() <= 1) {
                logger.warn("Skipping cross-validation for {}: training samples from at least two different classes required", key);
                continue;
            }
            // Get the current classifier - unfortunately, there's no easy way to duplicate/create a new one,
            // so we are left working with the 'live' classifier
            PathObjectClassifier liveClassifier = (T) comboClassifiers.getSelectionModel().getSelectedItem();
            if (liveClassifier == null) {
                logger.error("No classifier selected - cannot perform cross-validation");
                return;
            }
            liveClassifier.updateClassifier(trainingMap, featurePanel.getSelectedFeatures(), normalization);
            int nCorrect = 0;
            int nTested = 0;
            for (Entry<PathClass, List<PathObject>> entryValidation : validationMap.entrySet()) {
                liveClassifier.classifyPathObjects(entryValidation.getValue());
                for (PathObject temp : entryValidation.getValue()) {
                    if (entryValidation.getKey().equals(temp.getPathClass()))
                        nCorrect++;
                    nTested++;
                }
            }
            // Avoid reporting NaN when there was nothing to validate
            if (nTested == 0) {
                logger.warn("No validation objects found for {}", key);
                continue;
            }
            double percent = nCorrect * 100.0 / nTested;
            logger.info(String.format("Percentage correct for %s: %.2f%% (%d/%d)", key, percent, nCorrect, nTested));
        }
    } finally {
        // Force a normal classifier update, to compensate for the fact we had to modify the 'live' classifier
        updateClassification(false);
    }
}
Also used : PathObjectClassifier(qupath.lib.classifiers.PathObjectClassifier) LinkedHashMap(java.util.LinkedHashMap) PathClass(qupath.lib.objects.classes.PathClass) PathObject(qupath.lib.objects.PathObject) SplitType(qupath.process.gui.ml.legacy.PathClassificationLabellingHelper.SplitType) Normalization(qupath.lib.classifiers.Normalization) ParameterList(qupath.lib.plugins.parameters.ParameterList) List(java.util.List) ArrayList(java.util.ArrayList) Map(java.util.Map) LinkedHashMap(java.util.LinkedHashMap) TreeMap(java.util.TreeMap)

Example 32 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

The class ClassifierBuilderPane, method updateClassification.

/**
 * Build the training (and, if possible, test) data from the current training objects, then submit a
 * background task that trains the classifier and applies it to the hierarchy.
 * <p>
 * Sets {@code updatingClassification} to true before building the training data. It is reset to false
 * here if fewer than two classes are available for training.
 *
 * @param interactive if true, show an error when no hierarchy is available, and select all features
 *                    automatically when none are currently selected
 */
private synchronized void updateClassification(boolean interactive) {
    PathObjectHierarchy hierarchy = getHierarchy();
    if (hierarchy == null) {
        if (interactive)
            Dialogs.showErrorMessage("Classification error", "No objects available to classify!");
        // Guard against a missing classifier (as in the 'no features' case below)
        btnSaveClassifier.setDisable(classifier == null || !classifier.isValid());
        return;
    }
    List<String> features = featurePanel.getSelectedFeatures();
    // If we've no features, default to trying to use all available features
    if (features.isEmpty() && interactive) {
        selectAllFeatures();
        features = featurePanel.getSelectedFeatures();
        if (features.size() == 1)
            Dialogs.showInfoNotification("Feature selection", "Classifier set to train using the only available feature");
        else if (!features.isEmpty())
            Dialogs.showInfoNotification("Feature selection", "Classifier set to train using all " + features.size() + " available features");
    }
    // If still got no features, we're rather stuck
    if (features.isEmpty()) {
        Dialogs.showErrorMessage("Classification error", "No features available to use for classification!");
        btnSaveClassifier.setDisable(classifier == null || !classifier.isValid());
        return;
    }
    updatingClassification = true;
    // Get training map
    double maxTrainingProportion = paramsUpdate.getIntParameterValue("maxTrainingPercent") / 100.;
    long seed = paramsUpdate.getIntParameterValue("randomSeed");
    SplitType splitType = (SplitType) paramsUpdate.getChoiceParameterValue("splitType");
    Map<PathClass, List<PathObject>> map = getTrainingMap();
    // Apply limit if needed: drop objects already classified with a base class that isn't being trained
    if (paramsUpdate.getBooleanParameterValue("limitTrainingToRepresentedClasses")) {
        Set<PathClass> representedClasses = map.keySet();
        for (List<PathObject> values : map.values()) {
            values.removeIf(p -> {
                PathClass pathClass = p.getPathClass();
                return pathClass != null && !representedClasses.contains(pathClass.getBaseClass());
            });
        }
    }
    // TODO: The order of entries in the map is not necessarily consistent (e.g. when a new annotation is added to the hierarchy -
    // irrespective of whether or not it has a classification).  Consequently, classifiers that rely on 'randomness' (e.g. random forests...)
    // can give different results for the same training data.  With 'auto-update' selected, this looks somewhat disturbing...
    Map<PathClass, List<PathObject>> mapTraining = PathClassificationLabellingHelper.resampleClassificationMap(map, splitType, maxTrainingProportion, seed);
    if (mapTraining.size() <= 1) {
        logger.error("Training samples from at least two different classes required to train a classifier!");
        updatingClassification = false;
        return;
    }
    // Try to create a separate test map, if we can
    Map<PathClass, List<PathObject>> mapTest = map;
    boolean testOnTrainingData = true;
    if (mapTraining != map) {
        for (Entry<PathClass, List<PathObject>> entry : mapTraining.entrySet()) {
            mapTest.get(entry.getKey()).removeAll(entry.getValue());
            logger.info("Number of training samples for " + entry.getKey() + ": " + entry.getValue().size());
        }
        testOnTrainingData = false;
    }
    // Balance the classes for training, if necessary
    if (paramsUpdate.getBooleanParameterValue("balanceClasses")) {
        logger.debug("Balancing classes...");
        // Without a separate test set, mapTraining and mapTest are the same map: copy it so that the
        // oversampled duplicates only affect training, not the test data
        if (mapTraining == mapTest)
            mapTraining = new LinkedHashMap<>(mapTraining);
        int maxSize = -1;
        for (List<PathObject> temp : mapTraining.values())
            maxSize = Math.max(maxSize, temp.size());
        Random random = new Random(seed);
        for (PathClass key : mapTraining.keySet()) {
            List<PathObject> temp = mapTraining.get(key);
            int size = temp.size();
            // Nothing to do if already at the max size; nothing to sample from if empty
            if (maxSize == size || size == 0)
                continue;
            // Ensure a copy is made
            List<PathObject> list = new ArrayList<>(temp);
            for (int i = 0; i < maxSize - size; i++) {
                list.add(temp.get(random.nextInt(size)));
            }
            mapTraining.put(key, list);
        }
    }
    BackgroundClassificationTask task = new BackgroundClassificationTask(hierarchy, features, mapTraining, mapTest, testOnTrainingData);
    qupath.submitShortTask(task);
}
Also used : PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) ArrayList(java.util.ArrayList) PathClass(qupath.lib.objects.classes.PathClass) PathObject(qupath.lib.objects.PathObject) Random(java.util.Random) SplitType(qupath.process.gui.ml.legacy.PathClassificationLabellingHelper.SplitType) ParameterList(qupath.lib.plugins.parameters.ParameterList) List(java.util.List) ArrayList(java.util.ArrayList)

Example 33 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

The class ClassifierBuilderPane, method doClassification.

/**
 * Train the current classifier and apply it to the detection objects of the hierarchy.
 * <p>
 * When called off the JavaFX thread, the progress indicator is shown via {@code Platform.runLater}. The
 * detections are wrapped in {@code PathObjectClassificationProxy} objects before classifying. On success,
 * {@code completeClassification} receives both the classified objects and the original detections.
 *
 * @param hierarchy the hierarchy whose detection objects should be classified
 * @param features the names of the features used for training
 * @param mapTraining training objects, grouped by class
 * @param mapTest test objects, grouped by class; also classified when more than one image is retained
 * @param testOnTrainingData passed on to {@code completeClassification}; the caller sets it when no
 *                           separate test set could be created
 */
private void doClassification(final PathObjectHierarchy hierarchy, final List<String> features, final Map<PathClass, List<PathObject>> mapTraining, final Map<PathClass, List<PathObject>> mapTest, final boolean testOnTrainingData) {
    if (!Platform.isFxApplicationThread())
        Platform.runLater(() -> {
            progressIndicator.setProgress(-1);
            progressIndicator.setPrefSize(30, 30);
            progressIndicator.setVisible(true);
        });
    long startTime = System.currentTimeMillis();
    // Train classifier with requested normalization
    Normalization normalization = (Normalization) paramsUpdate.getChoiceParameterValue("normalizationMethod");
    String errorMessage = null;
    boolean classifierChanged = classifier != lastClassifierCompleted;
    try {
        classifierChanged = classifier.updateClassifier(mapTraining, features, normalization) || classifierChanged;
    } catch (Exception e) {
        errorMessage = "Classifier training failed with message:\n" + e.getLocalizedMessage() + "\nPlease try again with different settings.";
        logger.error("Classifier training failed", e);
    }
    if (classifier == null || !classifier.isValid()) {
        updateClassifierSummary(errorMessage);
        logger.error("Classifier is invalid!");
        updatingClassification = false;
        btnSaveClassifier.setDisable(classifier == null || !classifier.isValid());
        return;
    }
    long middleTime = System.currentTimeMillis();
    logger.info(String.format("Classifier training time: %.2f seconds", (middleTime - startTime) / 1000.));
    // Create an intensity classifier, if required
    PathObjectClassifier intensityClassifier = panelIntensities.getIntensityClassifier();
    // Apply classifier to everything
    Collection<PathObject> pathObjectsOrig = hierarchy.getDetectionObjects();
    int nClassified = 0;
    // Always work on a separate collection, so that the filtering/adding below never alters pathObjectsOrig.
    // Off the FX thread, use proxy objects rather than modifying the detections directly.
    Collection<PathObject> pathObjects;
    if (Platform.isFxApplicationThread())
        pathObjects = new ArrayList<>(pathObjectsOrig);
    else
        pathObjects = pathObjectsOrig.stream().map(p -> new PathObjectClassificationProxy(p)).collect(Collectors.toList());
    // Omit any objects that have already been classified by anything other than any of the target classes
    if (paramsUpdate.getBooleanParameterValue("limitTrainingToRepresentedClasses")) {
        Set<PathClass> representedClasses = mapTraining.keySet();
        pathObjects.removeIf(p -> {
            PathClass currentClass = p.getPathClass();
            return currentClass != null && !representedClasses.contains(currentClass.getBaseClass());
        });
    }
    // In the event that we're using retained images, ensure we classify everything in our test set
    if (retainedObjectsMap.size() > 1 && !mapTest.isEmpty()) {
        for (List<PathObject> testObjects : mapTest.values())
            pathObjects.addAll(testObjects);
    }
    if (classifierChanged || hierarchyChanged) {
        nClassified = classifier.classifyPathObjects(pathObjects);
    } else {
        logger.info("Main classifier unchanged...");
    }
    if (intensityClassifier != null)
        intensityClassifier.classifyPathObjects(pathObjects);
    if (nClassified <= 0 && (classifierChanged || hierarchyChanged))
        logger.error("Classification failed - no objects classified!");
    long endTime = System.currentTimeMillis();
    logger.info(String.format("Classification time: %.2f seconds", (endTime - middleTime) / 1000.));
    completeClassification(hierarchy, pathObjects, pathObjectsOrig, mapTest, testOnTrainingData);
}
Also used : Button(javafx.scene.control.Button) RunSavedClassifierWorkflowStep(qupath.lib.plugins.workflow.RunSavedClassifierWorkflowStep) Arrays(java.util.Arrays) BufferedInputStream(java.io.BufferedInputStream) ListCell(javafx.scene.control.ListCell) HierarchyEventType(qupath.lib.objects.hierarchy.events.PathObjectHierarchyEvent.HierarchyEventType) ObjectInputStream(java.io.ObjectInputStream) LoggerFactory(org.slf4j.LoggerFactory) Random(java.util.Random) VBox(javafx.scene.layout.VBox) Task(javafx.concurrent.Task) ParameterList(qupath.lib.plugins.parameters.ParameterList) ReadOnlyObjectWrapper(javafx.beans.property.ReadOnlyObjectWrapper) ComboBox(javafx.scene.control.ComboBox) ContextMenu(javafx.scene.control.ContextMenu) PathObjectHierarchyEvent(qupath.lib.objects.hierarchy.events.PathObjectHierarchyEvent) Map(java.util.Map) StandardPathClasses(qupath.lib.objects.classes.PathClassFactory.StandardPathClasses) Parameterizable(qupath.lib.plugins.parameters.Parameterizable) TableView(javafx.scene.control.TableView) QuPathGUI(qupath.lib.gui.QuPathGUI) Pane(javafx.scene.layout.Pane) PrintWriter(java.io.PrintWriter) MenuItem(javafx.scene.control.MenuItem) BufferedImage(java.awt.image.BufferedImage) Collection(java.util.Collection) Set(java.util.Set) Collectors(java.util.stream.Collectors) FileNotFoundException(java.io.FileNotFoundException) PathAnnotationObject(qupath.lib.objects.PathAnnotationObject) PathObject(qupath.lib.objects.PathObject) Platform(javafx.application.Platform) SeparatorMenuItem(javafx.scene.control.SeparatorMenuItem) Priority(javafx.scene.layout.Priority) List(java.util.List) Project(qupath.lib.projects.Project) ToggleButton(javafx.scene.control.ToggleButton) PathObjectClassifier(qupath.lib.classifiers.PathObjectClassifier) Entry(java.util.Map.Entry) BorderPane(javafx.scene.layout.BorderPane) Scene(javafx.scene.Scene) ListView(javafx.scene.control.ListView) TextArea(javafx.scene.control.TextArea) 
PathClassFactory(qupath.lib.objects.classes.PathClassFactory) PathObjectHierarchy(qupath.lib.objects.hierarchy.PathObjectHierarchy) TreeSet(java.util.TreeSet) BufferedOutputStream(java.io.BufferedOutputStream) ArrayList(java.util.ArrayList) TableColumn(javafx.scene.control.TableColumn) LinkedHashMap(java.util.LinkedHashMap) Dialogs(qupath.lib.gui.dialogs.Dialogs) Insets(javafx.geometry.Insets) ProgressBar(javafx.scene.control.ProgressBar) Normalization(qupath.lib.classifiers.Normalization) ObjectOutputStream(java.io.ObjectOutputStream) SplitType(qupath.process.gui.ml.legacy.PathClassificationLabellingHelper.SplitType) Callback(javafx.util.Callback) Tooltip(javafx.scene.control.Tooltip) GridPane(javafx.scene.layout.GridPane) ImageData(qupath.lib.images.ImageData) ProgressIndicator(javafx.scene.control.ProgressIndicator) PathClassifierTools(qupath.lib.classifiers.PathClassifierTools) Modality(javafx.stage.Modality) Logger(org.slf4j.Logger) Label(javafx.scene.control.Label) TitledPane(javafx.scene.control.TitledPane) Iterator(java.util.Iterator) ProjectImageEntry(qupath.lib.projects.ProjectImageEntry) PathClass(qupath.lib.objects.classes.PathClass) FileOutputStream(java.io.FileOutputStream) IOException(java.io.IOException) FileInputStream(java.io.FileInputStream) File(java.io.File) PathObjectTools(qupath.lib.objects.PathObjectTools) Cursor(javafx.scene.Cursor) ROI(qupath.lib.roi.interfaces.ROI) SelectionMode(javafx.scene.control.SelectionMode) TreeMap(java.util.TreeMap) Stage(javafx.stage.Stage) ParameterPanelFX(qupath.lib.gui.dialogs.ParameterPanelFX) ObservableValue(javafx.beans.value.ObservableValue) PathObjectHierarchyListener(qupath.lib.objects.hierarchy.events.PathObjectHierarchyListener) ChangeListener(javafx.beans.value.ChangeListener) Collections(java.util.Collections) PathObjectClassifier(qupath.lib.classifiers.PathObjectClassifier) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) 
PathClass(qupath.lib.objects.classes.PathClass) PathObject(qupath.lib.objects.PathObject) Normalization(qupath.lib.classifiers.Normalization) ParameterList(qupath.lib.plugins.parameters.ParameterList) List(java.util.List) ArrayList(java.util.ArrayList)

Example 34 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

The class CreateTrainingImageCommand, method promptToCreateTrainingImage.

/**
 * Prompt to create a training image, based upon annotations throughout a project.
 * <p>
 * The user chooses a classification and options in a parameter dialog. A sparse image server is then
 * built in a background task while a progress dialog is shown. If successful, the image is added to
 * the project. The chosen options are stored for the next call.
 *
 * @param project the project containing the annotated images
 * @param availableClasses the classifications the user may choose from
 * @return the entry of the new training image, created within the project; or null if there is no
 *         project or classification, the dialog is cancelled, no suitable annotations are found,
 *         or an error occurs
 */
public static ProjectImageEntry<BufferedImage> promptToCreateTrainingImage(Project<BufferedImage> project, List<PathClass> availableClasses) {
    if (project == null) {
        Dialogs.showErrorMessage(NAME, "You need a project!");
        return null;
    }
    if (availableClasses == null || availableClasses.isEmpty()) {
        Dialogs.showErrorMessage(NAME, "Please ensure classifications are available in QuPath!");
        return null;
    }
    List<PathClass> pathClasses = new ArrayList<>(availableClasses);
    // Fall back to the first available class if the previously-used class isn't available
    if (!pathClasses.contains(pathClass))
        pathClass = pathClasses.get(0);
    var params = new ParameterList()
            .addEmptyParameter("Generates a single image from regions extracted from the project.")
            .addEmptyParameter("Before running this command, add classified rectangle annotations to select the regions.")
            .addChoiceParameter("pathClass", "Classification", pathClass, pathClasses, "Select classification for annotated regions")
            .addIntParameter("maxWidth", "Preferred image width", maxWidth, "px", "Preferred maximum width of the training image, in pixels")
            .addBooleanParameter("doZ", "Do z-stacks", doZ, "Take all slices of a z-stack, where possible")
            .addBooleanParameter("rectanglesOnly", "Rectangles only", rectanglesOnly, "Only extract regions annotated with rectangles. Otherwise, the bounding box of all regions with the classification will be taken.")
            .addEmptyParameter("Note this command requires images to have similar bit-depths/channels/pixel sizes for compatibility.");
    if (!Dialogs.showParameterDialog(NAME, params))
        return null;
    pathClass = (PathClass) params.getChoiceParameterValue("pathClass");
    maxWidth = params.getIntParameterValue("maxWidth");
    doZ = params.getBooleanParameterValue("doZ");
    rectanglesOnly = params.getBooleanParameterValue("rectanglesOnly");
    var task = new Task<SparseImageServer>() {

        @Override
        protected SparseImageServer call() throws Exception {
            return createSparseServer(project, pathClass, maxWidth, doZ, rectanglesOnly);
        }
    };
    var dialog = new ProgressDialog(task);
    dialog.setTitle(NAME);
    dialog.setHeaderText("Creating training image...");
    // Shut the executor down once the task is submitted: the task still runs to completion,
    // but the worker thread then terminates instead of lingering after every call
    var executor = Executors.newSingleThreadExecutor();
    try {
        executor.submit(task);
    } finally {
        executor.shutdown();
    }
    dialog.showAndWait();
    try {
        var server = task.get();
        if (server == null) {
            Dialogs.showErrorMessage("Sparse image server", "No suitable annotations found in the current project!");
            return null;
        }
        ProjectImageEntry<BufferedImage> entry;
        try {
            if (server.getManager().getRegions().isEmpty()) {
                Dialogs.showErrorMessage("Sparse image server", "No suitable annotations found in the current project!");
                return null;
            }
            entry = ProjectCommands.addSingleImageToProject(project, server, null);
        } finally {
            // Release the server even if it was empty or adding it to the project failed
            server.close();
        }
        project.syncChanges();
        return entry;
    } catch (Exception e) {
        Dialogs.showErrorMessage("Sparse image server", e);
        return null;
    }
}
Also used : PathClass(qupath.lib.objects.classes.PathClass) Task(javafx.concurrent.Task) ArrayList(java.util.ArrayList) ParameterList(qupath.lib.plugins.parameters.ParameterList) ProgressDialog(org.controlsfx.dialog.ProgressDialog) IOException(java.io.IOException)

Example 35 with PathClass

use of qupath.lib.objects.classes.PathClass in project qupath by qupath.

The class PixelClassifierPane, method doClassification.

/**
 * Train the selected pixel classification model from the current training data, and display its
 * predictions as a live overlay.
 * <p>
 * Steps:
 * <ul>
 *   <li>Optionally limit the training samples to {@code maxSamples}; the remainder is held out.</li>
 *   <li>Build a preprocessor from the selected normalization.</li>
 *   <li>Optionally reweight samples by inverse class frequency, then train the model.</li>
 * </ul>
 * A rough accuracy is logged: on the held-out samples if there are any, otherwise on the training
 * samples. Finally, a new pixel classifier is stored in {@code currentClassifier}, and a
 * {@code PixelClassificationOverlay} replaces the current overlay.
 */
private void doClassification() {
    var imageData = qupath.getImageData();
    if (imageData == null) {
        if (!qupath.getViewers().stream().anyMatch(v -> v.getImageData() != null)) {
            logger.debug("doClassification() called, but no images are open");
            return;
        }
    }
    var model = selectedClassifier.get();
    if (model == null) {
        Dialogs.showErrorNotification("Pixel classifier", "No classifier selected!");
        return;
    }
    ClassifierTrainingData trainingData;
    try {
        var trainingImages = getTrainingImageData();
        if (trainingImages.size() > 1)
            logger.info("Creating training data from {} images", trainingImages.size());
        trainingData = helper.createTrainingData(trainingImages);
    } catch (Exception e) {
        logger.error("Error when updating training data", e);
        return;
    }
    if (trainingData == null) {
        resetPieChart();
        return;
    }
    // TODO: Optionally limit the number of training samples we use
    // Ensure we seed the RNG for reproducibility
    opencv_core.setRNGSeed(rngSeed);
    // TODO: Prevent training K nearest neighbor with a huge number of samples (very slow!)
    var actualMaxSamples = this.maxSamples;
    // Keep a reference to the split data. The TrainData created for the model below is built only from
    // the training samples, so any held-out samples must be taken from here.
    var splitData = trainingData.getTrainData();
    if (actualMaxSamples > 0 && splitData.getNTrainSamples() > actualMaxSamples)
        splitData.setTrainTestSplit(actualMaxSamples, true);
    else
        splitData.shuffleTrainTest();
    // Apply normalization, if we need to
    FeaturePreprocessor preprocessor = normalization.build(splitData.getTrainSamples(), false);
    if (preprocessor.doesSomething()) {
        preprocessingOp = ImageOps.ML.preprocessor(preprocessor);
    } else
        preprocessingOp = null;
    var labels = trainingData.getLabelMap();
    // Using getTrainNormCatResponses() causes confusion if classes are not represented
    var targets = splitData.getTrainResponses();
    IntBuffer buffer = targets.createBuffer();
    int n = (int) targets.total();
    // NOTE(review): assumes label values lie in [0, labels.size()) - TODO confirm
    var rawCounts = new int[labels.size()];
    for (int i = 0; i < n; i++) {
        rawCounts[buffer.get(i)] += 1;
    }
    Map<PathClass, Integer> counts = new LinkedHashMap<>();
    for (var entry : labels.entrySet()) {
        counts.put(entry.getKey(), rawCounts[entry.getValue()]);
    }
    updatePieChart(counts);
    Mat weights = null;
    if (reweightSamples) {
        // Weight each sample by (total samples / samples of its class), so rare classes count more
        weights = new Mat(n, 1, opencv_core.CV_32FC1);
        FloatIndexer bufferWeights = weights.createIndexer();
        float[] weightArray = new float[rawCounts.length];
        for (int i = 0; i < weightArray.length; i++) {
            int c = rawCounts[i];
            weightArray[i] = c == 0 ? 1 : (float) n / c;
        }
        for (int i = 0; i < n; i++) {
            int label = buffer.get(i);
            bufferWeights.put(i, weightArray[label]);
        }
        bufferWeights.release();
    }
    // Create TrainData in an appropriate format (e.g. labels or one-hot encoding)
    var trainSamples = splitData.getTrainSamples();
    var trainResponses = splitData.getTrainResponses();
    preprocessor.apply(trainSamples, false);
    var trainData = model.createTrainData(trainSamples, trainResponses, weights, false);
    logger.info("Training data: {} x {}, Target data: {} x {}", trainSamples.rows(), trainSamples.cols(), trainResponses.rows(), trainResponses.cols());
    model.train(trainData);
    // Calculate accuracy using whatever we can, as a rough guide to progress
    var test = splitData.getTestSamples();
    String testSet = "HELD-OUT TRAINING SET";
    if (test.empty()) {
        test = trainSamples;
        testSet = "TRAINING SET";
    } else {
        preprocessor.apply(test, false);
        // Use the raw responses (as for training) so they are directly comparable with predicted labels
        buffer = splitData.getTestResponses().createBuffer();
    }
    var testResults = new Mat();
    model.predict(test, testResults, null);
    IntBuffer bufferResults = testResults.createBuffer();
    int nTest = (int) testResults.rows();
    int nCorrect = 0;
    for (int i = 0; i < nTest; i++) {
        if (bufferResults.get(i) == buffer.get(i))
            nCorrect++;
    }
    // Accuracy is relative to the number of samples actually tested
    logger.info("Current accuracy on the {}: {} %", testSet, GeneralTools.formatNumber(nCorrect * 100.0 / nTest, 1));
    if (model instanceof RTreesClassifier) {
        var trees = (RTreesClassifier) model;
        if (trees.hasFeatureImportance() && imageData != null)
            logVariableImportance(trees, helper.getFeatureOp().getChannels(imageData).stream().map(c -> c.getName()).collect(Collectors.toList()));
    }
    trainData.close();
    var featureCalculator = helper.getFeatureOp();
    if (preprocessingOp != null)
        featureCalculator = featureCalculator.appendOps(preprocessingOp);
    // TODO: CHECK IF INPUT SIZE SHOULD BE DEFINED
    int inputWidth = 512;
    int inputHeight = 512;
    var cal = helper.getResolution();
    var channelType = ImageServerMetadata.ChannelType.CLASSIFICATION;
    if (model.supportsProbabilities()) {
        channelType = selectedOutputType.get();
    }
    // Channels are needed for probability output (and work for classification as well)
    var labels2 = new TreeMap<Integer, PathClass>();
    for (var entry : labels.entrySet()) {
        var previous = labels2.put(entry.getValue(), entry.getKey());
        if (previous != null)
            logger.warn("Duplicate label found! {} matches with {} and {}, only the latter be used", entry.getValue(), previous, entry.getKey());
    }
    var channels = PathClassifierTools.classificationLabelsToChannels(labels2, true);
    PixelClassifierMetadata metadata = new PixelClassifierMetadata.Builder().inputResolution(cal).inputShape(inputWidth, inputHeight).setChannelType(channelType).outputChannels(channels).build();
    currentClassifier.set(PixelClassifiers.createClassifier(model, featureCalculator, metadata, true));
    var overlay = PixelClassificationOverlay.create(qupath.getOverlayOptions(), currentClassifier.get(), getLivePredictionThreads());
    replaceOverlay(overlay);
}
Also used : Arrays(java.util.Arrays) IJTools(qupath.imagej.tools.IJTools) ParameterList(qupath.lib.plugins.parameters.ParameterList) Map(java.util.Map) Point2D(javafx.geometry.Point2D) IJExtension(qupath.imagej.gui.IJExtension) Platform(javafx.application.Platform) PieChart(javafx.scene.chart.PieChart) BooleanProperty(javafx.beans.property.BooleanProperty) Region(javafx.scene.layout.Region) ObservableList(javafx.collections.ObservableList) BorderPane(javafx.scene.layout.BorderPane) StringProperty(javafx.beans.property.StringProperty) CompositeImage(ij.CompositeImage) RectangleROI(qupath.lib.roi.RectangleROI) org.bytedeco.opencv.global.opencv_core(org.bytedeco.opencv.global.opencv_core) FXCollections(javafx.collections.FXCollections) Bindings(javafx.beans.binding.Bindings) IntegerProperty(javafx.beans.property.IntegerProperty) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) KNearest(org.bytedeco.opencv.opencv_ml.KNearest) ImageOps(qupath.opencv.ops.ImageOps) Slider(javafx.scene.control.Slider) Mat(org.bytedeco.opencv.opencv_core.Mat) TextAlignment(javafx.scene.text.TextAlignment) GridPane(javafx.scene.layout.GridPane) PathClassifierTools(qupath.lib.classifiers.PathClassifierTools) GeneralTools(qupath.lib.common.GeneralTools) RegionRequest(qupath.lib.regions.RegionRequest) IOException(java.io.IOException) ChartTools(qupath.lib.gui.charts.ChartTools) PixelClassifier(qupath.lib.classifiers.pixel.PixelClassifier) TreeMap(java.util.TreeMap) FloatIndexer(org.bytedeco.javacpp.indexer.FloatIndexer) SimpleObjectProperty(javafx.beans.property.SimpleObjectProperty) PixelCalibration(qupath.lib.images.servers.PixelCalibration) ObservableValue(javafx.beans.value.ObservableValue) PathObjectHierarchyListener(qupath.lib.objects.hierarchy.events.PathObjectHierarchyListener) ANN_MLP(org.bytedeco.opencv.opencv_ml.ANN_MLP) PathPrefs(qupath.lib.gui.prefs.PathPrefs) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) 
PaneTools(qupath.lib.gui.tools.PaneTools) EventHandler(javafx.event.EventHandler) Button(javafx.scene.control.Button) Pos(javafx.geometry.Pos) ImageServer(qupath.lib.images.servers.ImageServer) ListCell(javafx.scene.control.ListCell) LoggerFactory(org.slf4j.LoggerFactory) OverrunStyle(javafx.scene.control.OverrunStyle) RTreesClassifier(qupath.opencv.ml.OpenCVClassifiers.RTreesClassifier) OpenCVStatModel(qupath.opencv.ml.OpenCVClassifiers.OpenCVStatModel) Side(javafx.geometry.Side) PixelClassificationOverlay(qupath.lib.gui.viewer.overlays.PixelClassificationOverlay) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) ComboBox(javafx.scene.control.ComboBox) IntBuffer(java.nio.IntBuffer) PathObjectHierarchyEvent(qupath.lib.objects.hierarchy.events.PathObjectHierarchyEvent) QuPathGUI(qupath.lib.gui.QuPathGUI) BufferedImage(java.awt.image.BufferedImage) PixelClassifiers(qupath.opencv.ml.pixel.PixelClassifiers) Collection(java.util.Collection) Spinner(javafx.scene.control.Spinner) Collectors(java.util.stream.Collectors) FeaturePreprocessor(qupath.opencv.ml.FeaturePreprocessor) QuPathViewer(qupath.lib.gui.viewer.QuPathViewer) Priority(javafx.scene.layout.Priority) List(java.util.List) ToggleButton(javafx.scene.control.ToggleButton) GuiTools(qupath.lib.gui.tools.GuiTools) ColorToolsFX(qupath.lib.gui.tools.ColorToolsFX) IntStream(java.util.stream.IntStream) Scene(javafx.scene.Scene) ListView(javafx.scene.control.ListView) ReadOnlyObjectProperty(javafx.beans.property.ReadOnlyObjectProperty) SimpleStringProperty(javafx.beans.property.SimpleStringProperty) ButtonType(javafx.scene.control.ButtonType) MouseEvent(javafx.scene.input.MouseEvent) ProjectDialogs(qupath.lib.gui.dialogs.ProjectDialogs) Dialogs(qupath.lib.gui.dialogs.Dialogs) Insets(javafx.geometry.Insets) Normalization(qupath.lib.classifiers.Normalization) Callback(javafx.util.Callback) Tooltip(javafx.scene.control.Tooltip) WeakHashMap(java.util.WeakHashMap) ImageData(qupath.lib.images.ImageData) 
ObjectProperty(javafx.beans.property.ObjectProperty) Logger(org.slf4j.Logger) Label(javafx.scene.control.Label) ProjectImageEntry(qupath.lib.projects.ProjectImageEntry) PathClass(qupath.lib.objects.classes.PathClass) ImageOp(qupath.opencv.ops.ImageOp) PixelClassifierMetadata(qupath.lib.classifiers.pixel.PixelClassifierMetadata) OpenCVClassifiers(qupath.opencv.ml.OpenCVClassifiers) RTrees(org.bytedeco.opencv.opencv_ml.RTrees) LogisticRegression(org.bytedeco.opencv.opencv_ml.LogisticRegression) SimpleBooleanProperty(javafx.beans.property.SimpleBooleanProperty) Stage(javafx.stage.Stage) ClassifierTrainingData(qupath.process.gui.commands.ml.PixelClassifierTraining.ClassifierTrainingData) MiniViewers(qupath.lib.gui.commands.MiniViewers) Comparator(java.util.Comparator) ChangeListener(javafx.beans.value.ChangeListener) Collections(java.util.Collections) ContentDisplay(javafx.scene.control.ContentDisplay) Mat(org.bytedeco.opencv.opencv_core.Mat) ClassifierTrainingData(qupath.process.gui.commands.ml.PixelClassifierTraining.ClassifierTrainingData) FloatIndexer(org.bytedeco.javacpp.indexer.FloatIndexer) TreeMap(java.util.TreeMap) IOException(java.io.IOException) FeaturePreprocessor(qupath.opencv.ml.FeaturePreprocessor) RTreesClassifier(qupath.opencv.ml.OpenCVClassifiers.RTreesClassifier) LinkedHashMap(java.util.LinkedHashMap) PathClass(qupath.lib.objects.classes.PathClass) PixelClassifierMetadata(qupath.lib.classifiers.pixel.PixelClassifierMetadata) IntBuffer(java.nio.IntBuffer)

Aggregations

PathClass (qupath.lib.objects.classes.PathClass)66 ArrayList (java.util.ArrayList)42 PathObject (qupath.lib.objects.PathObject)34 List (java.util.List)29 Map (java.util.Map)25 IOException (java.io.IOException)21 Logger (org.slf4j.Logger)20 LoggerFactory (org.slf4j.LoggerFactory)20 Collections (java.util.Collections)17 Collectors (java.util.stream.Collectors)17 BufferedImage (java.awt.image.BufferedImage)16 LinkedHashMap (java.util.LinkedHashMap)16 ROI (qupath.lib.roi.interfaces.ROI)16 HashMap (java.util.HashMap)15 ImageData (qupath.lib.images.ImageData)15 PathClassFactory (qupath.lib.objects.classes.PathClassFactory)15 PathObjectHierarchy (qupath.lib.objects.hierarchy.PathObjectHierarchy)15 ParameterList (qupath.lib.plugins.parameters.ParameterList)15 Collection (java.util.Collection)14 TreeMap (java.util.TreeMap)11