Search in sources :

Example 1 with PixelClassifierMetadata

Use of qupath.lib.classifiers.pixel.PixelClassifierMetadata in the project qupath, by qupath.

From the class PixelClassifierPane, method doClassification.

/**
 * Train the selected classifier with the current training data and display the result
 * as a new pixel classification overlay.
 * <p>
 * The method returns early, before a new classifier is set, if no image is open in any
 * viewer, if no classifier is selected, or if training data cannot be created. Otherwise it
 * seeds the OpenCV RNG, optionally limits and reweights the training samples, builds the
 * normalization preprocessor, trains the model, logs a rough accuracy estimate, and wraps
 * the trained model in a pixel classifier whose overlay replaces the existing one.
 */
private void doClassification() {
    var imageData = qupath.getImageData();
    if (imageData == null) {
        // Nothing in the current viewer: only continue if another viewer has an image
        if (qupath.getViewers().stream().noneMatch(v -> v.getImageData() != null)) {
            logger.debug("doClassification() called, but no images are open");
            return;
        }
    }
    var model = selectedClassifier.get();
    if (model == null) {
        Dialogs.showErrorNotification("Pixel classifier", "No classifier selected!");
        return;
    }
    ClassifierTrainingData trainingData;
    try {
        var trainingImages = getTrainingImageData();
        if (trainingImages.size() > 1)
            logger.info("Creating training data from {} images", trainingImages.size());
        trainingData = helper.createTrainingData(trainingImages);
    } catch (Exception e) {
        logger.error("Error when updating training data", e);
        return;
    }
    if (trainingData == null) {
        resetPieChart();
        return;
    }
    // TODO: Optionally limit the number of training samples we use
    // Ensure we seed the RNG for reproducibility
    opencv_core.setRNGSeed(rngSeed);
    // TODO: Prevent training K nearest neighbor with a huge number of samples (very slow!)
    var actualMaxSamples = this.maxSamples;
    var trainData = trainingData.getTrainData();
    if (actualMaxSamples > 0 && trainData.getNTrainSamples() > actualMaxSamples)
        trainData.setTrainTestSplit(actualMaxSamples, true);
    else
        trainData.shuffleTrainTest();
    // Apply normalization, if we need to
    FeaturePreprocessor preprocessor = normalization.build(trainData.getTrainSamples(), false);
    if (preprocessor.doesSomething()) {
        preprocessingOp = ImageOps.ML.preprocessor(preprocessor);
    } else
        preprocessingOp = null;
    var labels = trainingData.getLabelMap();
    // Using getTrainNormCatResponses() causes confusion if classes are not represented,
    // so the raw responses are used as label indices instead
    var targets = trainData.getTrainResponses();
    IntBuffer buffer = targets.createBuffer();
    int n = (int) targets.total();
    // Count the training samples per label; the label value is used directly as array index
    var rawCounts = new int[labels.size()];
    for (int i = 0; i < n; i++) {
        rawCounts[buffer.get(i)] += 1;
    }
    Map<PathClass, Integer> counts = new LinkedHashMap<>();
    for (var entry : labels.entrySet()) {
        counts.put(entry.getKey(), rawCounts[entry.getValue()]);
    }
    updatePieChart(counts);
    Mat weights = null;
    if (reweightSamples) {
        // One weight per sample: n / (count of the sample's label), or 1 for an empty label
        weights = new Mat(n, 1, opencv_core.CV_32FC1);
        FloatIndexer bufferWeights = weights.createIndexer();
        float[] weightArray = new float[rawCounts.length];
        for (int i = 0; i < weightArray.length; i++) {
            int c = rawCounts[i];
            weightArray[i] = c == 0 ? 1 : (float) n / c;
        }
        for (int i = 0; i < n; i++) {
            int label = buffer.get(i);
            bufferWeights.put(i, weightArray[label]);
        }
        bufferWeights.release();
    }
    // Create TrainData in an appropriate format (e.g. labels or one-hot encoding)
    var trainSamples = trainData.getTrainSamples();
    var trainResponses = trainData.getTrainResponses();
    preprocessor.apply(trainSamples, false);
    trainData = model.createTrainData(trainSamples, trainResponses, weights, false);
    logger.info("Training data: {} x {}, Target data: {} x {}", trainSamples.rows(), trainSamples.cols(), trainResponses.rows(), trainResponses.cols());
    model.train(trainData);
    // Calculate accuracy using whatever we can, as a rough guide to progress
    var test = trainData.getTestSamples();
    String testSet = "HELD-OUT TRAINING SET";
    if (test.empty()) {
        // No held-out samples: fall back to the training samples, compared against 'buffer'
        // (which still holds the training responses)
        test = trainSamples;
        testSet = "TRAINING SET";
    } else {
        preprocessor.apply(test, false);
        // NOTE(review): training used the raw responses, but the held-out targets use the
        // normalized categorical responses - confirm that these are comparable
        buffer = trainData.getTestNormCatResponses().createBuffer();
    }
    var testResults = new Mat();
    model.predict(test, testResults, null);
    IntBuffer bufferResults = testResults.createBuffer();
    int nTest = testResults.rows();
    int nCorrect = 0;
    for (int i = 0; i < nTest; i++) {
        if (bufferResults.get(i) == buffer.get(i))
            nCorrect++;
    }
    // Divide by the number of samples actually evaluated, which differs from the number of
    // training samples when a held-out set is used
    logger.info("Current accuracy on the {}: {} %", testSet, GeneralTools.formatNumber(nCorrect * 100.0 / nTest, 1));
    if (model instanceof RTreesClassifier) {
        var trees = (RTreesClassifier) model;
        // Channel names require image data from the current viewer
        if (trees.hasFeatureImportance() && imageData != null)
            logVariableImportance(trees, helper.getFeatureOp().getChannels(imageData).stream().map(c -> c.getName()).collect(Collectors.toList()));
    }
    trainData.close();
    // Append the preprocessing (if any) so that prediction sees the same normalized features
    var featureCalculator = helper.getFeatureOp();
    if (preprocessingOp != null)
        featureCalculator = featureCalculator.appendOps(preprocessingOp);
    // TODO: CHECK IF INPUT SIZE SHOULD BE DEFINED
    int inputWidth = 512;
    int inputHeight = 512;
    var cal = helper.getResolution();
    var channelType = ImageServerMetadata.ChannelType.CLASSIFICATION;
    if (model.supportsProbabilities()) {
        channelType = selectedOutputType.get();
    }
    // Channels are needed for probability output (and work for classification as well)
    // Invert the label map (class -> index) into a map sorted by index
    var labels2 = new TreeMap<Integer, PathClass>();
    for (var entry : labels.entrySet()) {
        var previous = labels2.put(entry.getValue(), entry.getKey());
        if (previous != null)
            logger.warn("Duplicate label found! {} matches with {} and {}, only the latter be used", entry.getValue(), previous, entry.getKey());
    }
    var channels = PathClassifierTools.classificationLabelsToChannels(labels2, true);
    PixelClassifierMetadata metadata = new PixelClassifierMetadata.Builder().inputResolution(cal).inputShape(inputWidth, inputHeight).setChannelType(channelType).outputChannels(channels).build();
    currentClassifier.set(PixelClassifiers.createClassifier(model, featureCalculator, metadata, true));
    var overlay = PixelClassificationOverlay.create(qupath.getOverlayOptions(), currentClassifier.get(), getLivePredictionThreads());
    replaceOverlay(overlay);
}
Also used : Arrays(java.util.Arrays) IJTools(qupath.imagej.tools.IJTools) ParameterList(qupath.lib.plugins.parameters.ParameterList) Map(java.util.Map) Point2D(javafx.geometry.Point2D) IJExtension(qupath.imagej.gui.IJExtension) Platform(javafx.application.Platform) PieChart(javafx.scene.chart.PieChart) BooleanProperty(javafx.beans.property.BooleanProperty) Region(javafx.scene.layout.Region) ObservableList(javafx.collections.ObservableList) BorderPane(javafx.scene.layout.BorderPane) StringProperty(javafx.beans.property.StringProperty) CompositeImage(ij.CompositeImage) RectangleROI(qupath.lib.roi.RectangleROI) org.bytedeco.opencv.global.opencv_core(org.bytedeco.opencv.global.opencv_core) FXCollections(javafx.collections.FXCollections) Bindings(javafx.beans.binding.Bindings) IntegerProperty(javafx.beans.property.IntegerProperty) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) KNearest(org.bytedeco.opencv.opencv_ml.KNearest) ImageOps(qupath.opencv.ops.ImageOps) Slider(javafx.scene.control.Slider) Mat(org.bytedeco.opencv.opencv_core.Mat) TextAlignment(javafx.scene.text.TextAlignment) GridPane(javafx.scene.layout.GridPane) PathClassifierTools(qupath.lib.classifiers.PathClassifierTools) GeneralTools(qupath.lib.common.GeneralTools) RegionRequest(qupath.lib.regions.RegionRequest) IOException(java.io.IOException) ChartTools(qupath.lib.gui.charts.ChartTools) PixelClassifier(qupath.lib.classifiers.pixel.PixelClassifier) TreeMap(java.util.TreeMap) FloatIndexer(org.bytedeco.javacpp.indexer.FloatIndexer) SimpleObjectProperty(javafx.beans.property.SimpleObjectProperty) PixelCalibration(qupath.lib.images.servers.PixelCalibration) ObservableValue(javafx.beans.value.ObservableValue) PathObjectHierarchyListener(qupath.lib.objects.hierarchy.events.PathObjectHierarchyListener) ANN_MLP(org.bytedeco.opencv.opencv_ml.ANN_MLP) PathPrefs(qupath.lib.gui.prefs.PathPrefs) ImageServerMetadata(qupath.lib.images.servers.ImageServerMetadata) 
PaneTools(qupath.lib.gui.tools.PaneTools) EventHandler(javafx.event.EventHandler) Button(javafx.scene.control.Button) Pos(javafx.geometry.Pos) ImageServer(qupath.lib.images.servers.ImageServer) ListCell(javafx.scene.control.ListCell) LoggerFactory(org.slf4j.LoggerFactory) OverrunStyle(javafx.scene.control.OverrunStyle) RTreesClassifier(qupath.opencv.ml.OpenCVClassifiers.RTreesClassifier) OpenCVStatModel(qupath.opencv.ml.OpenCVClassifiers.OpenCVStatModel) Side(javafx.geometry.Side) PixelClassificationOverlay(qupath.lib.gui.viewer.overlays.PixelClassificationOverlay) ChannelType(qupath.lib.images.servers.ImageServerMetadata.ChannelType) ComboBox(javafx.scene.control.ComboBox) IntBuffer(java.nio.IntBuffer) PathObjectHierarchyEvent(qupath.lib.objects.hierarchy.events.PathObjectHierarchyEvent) QuPathGUI(qupath.lib.gui.QuPathGUI) BufferedImage(java.awt.image.BufferedImage) PixelClassifiers(qupath.opencv.ml.pixel.PixelClassifiers) Collection(java.util.Collection) Spinner(javafx.scene.control.Spinner) Collectors(java.util.stream.Collectors) FeaturePreprocessor(qupath.opencv.ml.FeaturePreprocessor) QuPathViewer(qupath.lib.gui.viewer.QuPathViewer) Priority(javafx.scene.layout.Priority) List(java.util.List) ToggleButton(javafx.scene.control.ToggleButton) GuiTools(qupath.lib.gui.tools.GuiTools) ColorToolsFX(qupath.lib.gui.tools.ColorToolsFX) IntStream(java.util.stream.IntStream) Scene(javafx.scene.Scene) ListView(javafx.scene.control.ListView) ReadOnlyObjectProperty(javafx.beans.property.ReadOnlyObjectProperty) SimpleStringProperty(javafx.beans.property.SimpleStringProperty) ButtonType(javafx.scene.control.ButtonType) MouseEvent(javafx.scene.input.MouseEvent) ProjectDialogs(qupath.lib.gui.dialogs.ProjectDialogs) Dialogs(qupath.lib.gui.dialogs.Dialogs) Insets(javafx.geometry.Insets) Normalization(qupath.lib.classifiers.Normalization) Callback(javafx.util.Callback) Tooltip(javafx.scene.control.Tooltip) WeakHashMap(java.util.WeakHashMap) ImageData(qupath.lib.images.ImageData) 
ObjectProperty(javafx.beans.property.ObjectProperty) Logger(org.slf4j.Logger) Label(javafx.scene.control.Label) ProjectImageEntry(qupath.lib.projects.ProjectImageEntry) PathClass(qupath.lib.objects.classes.PathClass) ImageOp(qupath.opencv.ops.ImageOp) PixelClassifierMetadata(qupath.lib.classifiers.pixel.PixelClassifierMetadata) OpenCVClassifiers(qupath.opencv.ml.OpenCVClassifiers) RTrees(org.bytedeco.opencv.opencv_ml.RTrees) LogisticRegression(org.bytedeco.opencv.opencv_ml.LogisticRegression) SimpleBooleanProperty(javafx.beans.property.SimpleBooleanProperty) Stage(javafx.stage.Stage) ClassifierTrainingData(qupath.process.gui.commands.ml.PixelClassifierTraining.ClassifierTrainingData) MiniViewers(qupath.lib.gui.commands.MiniViewers) Comparator(java.util.Comparator) ChangeListener(javafx.beans.value.ChangeListener) Collections(java.util.Collections) ContentDisplay(javafx.scene.control.ContentDisplay) Mat(org.bytedeco.opencv.opencv_core.Mat) ClassifierTrainingData(qupath.process.gui.commands.ml.PixelClassifierTraining.ClassifierTrainingData) FloatIndexer(org.bytedeco.javacpp.indexer.FloatIndexer) TreeMap(java.util.TreeMap) IOException(java.io.IOException) FeaturePreprocessor(qupath.opencv.ml.FeaturePreprocessor) RTreesClassifier(qupath.opencv.ml.OpenCVClassifiers.RTreesClassifier) LinkedHashMap(java.util.LinkedHashMap) PathClass(qupath.lib.objects.classes.PathClass) PixelClassifierMetadata(qupath.lib.classifiers.pixel.PixelClassifierMetadata) IntBuffer(java.nio.IntBuffer)

Aggregations

CompositeImage (ij.CompositeImage)1 BufferedImage (java.awt.image.BufferedImage)1 IOException (java.io.IOException)1 IntBuffer (java.nio.IntBuffer)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Collection (java.util.Collection)1 Collections (java.util.Collections)1 Comparator (java.util.Comparator)1 LinkedHashMap (java.util.LinkedHashMap)1 List (java.util.List)1 Map (java.util.Map)1 TreeMap (java.util.TreeMap)1 WeakHashMap (java.util.WeakHashMap)1 Collectors (java.util.stream.Collectors)1 IntStream (java.util.stream.IntStream)1 Platform (javafx.application.Platform)1 Bindings (javafx.beans.binding.Bindings)1 BooleanProperty (javafx.beans.property.BooleanProperty)1 IntegerProperty (javafx.beans.property.IntegerProperty)1