Use of qupath.lib.objects.classes.PathClass in project qupath (by qupath).
Class ClassifierBuilderPane, method crossValidateAcrossImages:
/**
 * Performs leave-one-image-out cross-validation over the retained training data.
 * <p>
 * For each image in the retained objects map, the classifier is trained on the objects from all
 * other images (after the same resampling used for normal training) and then applied to the
 * objects of the held-out image. The percentage of objects whose resulting class matches the
 * class they were grouped under is logged.
 * <p>
 * This works on the 'live' classifier selected in the combo box, and re-classifies the validation
 * objects. A normal classification update is forced at the end to compensate.
 */
private void crossValidateAcrossImages() {
    // Try to put the current image data information into the tempMap, which stores training data separated by image path
    updateRetainedObjectsMap();
    Map<String, Map<PathClass, List<PathObject>>> tempMap = new LinkedHashMap<>(retainedObjectsMap.getMap());
    if (tempMap.size() < 2) {
        // With a single image there is nothing left to train on once it is held out
        logger.warn("Cross-validation across images requires training data from at least two images");
        return;
    }
    // Get the current classifier - unfortunately, there's no easy way to duplicate/create a new one,
    // so we are left working with the 'live' classifier
    PathObjectClassifier classifier = (T) comboClassifiers.getSelectionModel().getSelectedItem();
    if (classifier == null) {
        logger.warn("No classifier selected - cannot cross-validate across images");
        return;
    }
    // These settings do not change during cross-validation, so read them once
    Normalization normalization = (Normalization) paramsUpdate.getChoiceParameterValue("normalizationMethod");
    SplitType splitType = (SplitType) paramsUpdate.getChoiceParameterValue("splitType");
    double maxTrainingProportion = paramsUpdate.getIntParameterValue("maxTrainingPercent") / 100.;
    long seed = paramsUpdate.getIntParameterValue("randomSeed");
    List<String> features = featurePanel.getSelectedFeatures();

    for (Entry<String, Map<PathClass, List<PathObject>>> validationEntry : tempMap.entrySet()) {
        String key = validationEntry.getKey();
        Map<PathClass, List<PathObject>> validationMap = validationEntry.getValue();

        // Pool the objects from every other image, copying the lists so the retained map is not modified
        Map<PathClass, List<PathObject>> trainingMap = new LinkedHashMap<>();
        for (Entry<String, Map<PathClass, List<PathObject>>> entry : tempMap.entrySet()) {
            if (entry.getKey().equals(key))
                continue;
            for (Entry<PathClass, List<PathObject>> entry2 : entry.getValue().entrySet()) {
                trainingMap.computeIfAbsent(entry2.getKey(), k -> new ArrayList<>()).addAll(entry2.getValue());
            }
        }

        // Perform subsampling
        trainingMap = PathClassificationLabellingHelper.resampleClassificationMap(trainingMap, splitType, maxTrainingProportion, seed);
        classifier.updateClassifier(trainingMap, features, normalization);

        // Classify the held-out objects and compare with the class they were grouped under
        int nCorrect = 0;
        int nTested = 0;
        for (Entry<PathClass, List<PathObject>> entryValidation : validationMap.entrySet()) {
            classifier.classifyPathObjects(entryValidation.getValue());
            for (PathObject temp : entryValidation.getValue()) {
                if (entryValidation.getKey().equals(temp.getPathClass()))
                    nCorrect++;
                nTested++;
            }
        }
        if (nTested == 0) {
            // Avoid reporting NaN when an image contributes no validation objects
            logger.warn("No validation objects available for {}", key);
            continue;
        }
        double percent = nCorrect * 100.0 / nTested;
        logger.info(String.format("Percentage correct for %s: %.2f%% (%d/%d)", key, percent, nCorrect, nTested));
    }
    // Force a normal classifier update, to compensate for the fact we had to modify the 'live' classifier
    updateClassification(false);
}
Use of qupath.lib.objects.classes.PathClass in project qupath (by qupath).
Class ClassifierBuilderPane, method updateClassification:
/**
 * Builds the training data, and a separate test set where possible, from {@code getTrainingMap()}.
 * It then optionally balances the classes and submits a {@code BackgroundClassificationTask} to
 * train and apply the classifier.
 *
 * @param interactive if true, an error dialog is shown when no hierarchy is available, and all
 *                    features are selected automatically when none are currently chosen
 */
private synchronized void updateClassification(boolean interactive) {
    PathObjectHierarchy hierarchy = getHierarchy();
    if (hierarchy == null) {
        if (interactive)
            Dialogs.showErrorMessage("Classification error", "No objects available to classify!");
        // Null-check the classifier, as on the other early-return paths
        btnSaveClassifier.setDisable(classifier == null || !classifier.isValid());
        return;
    }
    List<String> features = featurePanel.getSelectedFeatures();
    // If we've no features, default to trying to use all available features
    if (features.isEmpty() && interactive) {
        selectAllFeatures();
        features = featurePanel.getSelectedFeatures();
        if (features.size() == 1)
            Dialogs.showInfoNotification("Feature selection", "Classifier set to train using the only available feature");
        else if (!features.isEmpty())
            Dialogs.showInfoNotification("Feature selection", "Classifier set to train using all " + features.size() + " available features");
    }
    // If still got no features, we're rather stuck
    if (features.isEmpty()) {
        Dialogs.showErrorMessage("Classification error", "No features available to use for classification!");
        btnSaveClassifier.setDisable(classifier == null || !classifier.isValid());
        return;
    }
    updatingClassification = true;

    // Get training map
    double maxTrainingProportion = paramsUpdate.getIntParameterValue("maxTrainingPercent") / 100.;
    long seed = paramsUpdate.getIntParameterValue("randomSeed");
    SplitType splitType = (SplitType) paramsUpdate.getChoiceParameterValue("splitType");
    Map<PathClass, List<PathObject>> map = getTrainingMap();

    // Apply limit if needed: drop objects that already have a class whose base class is not a training class
    if (paramsUpdate.getBooleanParameterValue("limitTrainingToRepresentedClasses")) {
        Set<PathClass> representedClasses = map.keySet();
        for (List<PathObject> values : map.values()) {
            values.removeIf(pathObject -> {
                PathClass currentClass = pathObject.getPathClass();
                return currentClass != null && !representedClasses.contains(currentClass.getBaseClass());
            });
        }
    }

    // TODO: The order of entries in the map is not necessarily consistent (e.g. when a new annotation is added to the hierarchy -
    // irrespective of whether or not it has a classification). Consequently, classifiers that rely on 'randomness' (e.g. random forests...)
    // can give different results for the same training data. With 'auto-update' selected, this looks somewhat disturbing...
    Map<PathClass, List<PathObject>> mapTraining = PathClassificationLabellingHelper.resampleClassificationMap(map, splitType, maxTrainingProportion, seed);
    if (mapTraining.size() <= 1) {
        logger.error("Training samples from at least two different classes required to train a classifier!");
        updatingClassification = false;
        return;
    }

    // Try to create a separate test map, if we can.
    // If resampling returned a different map, the test set is whatever was not selected for training.
    Map<PathClass, List<PathObject>> mapTest = map;
    boolean testOnTrainingData = true;
    if (mapTraining != map) {
        for (Entry<PathClass, List<PathObject>> entry : mapTraining.entrySet()) {
            // Hash-based lookup keeps the removal linear rather than quadratic for large object lists
            mapTest.get(entry.getKey()).removeAll(new java.util.HashSet<>(entry.getValue()));
            logger.info("Number of training samples for " + entry.getKey() + ": " + entry.getValue().size());
        }
        testOnTrainingData = false;
    }

    // Balance the classes for training, if necessary, by oversampling (with replacement) up to the largest class size
    if (paramsUpdate.getBooleanParameterValue("balanceClasses")) {
        logger.debug("Balancing classes...");
        int maxSize = -1;
        for (List<PathObject> temp : mapTraining.values())
            maxSize = Math.max(maxSize, temp.size());
        Random random = new Random(seed);
        for (Entry<PathClass, List<PathObject>> entry : mapTraining.entrySet()) {
            List<PathObject> temp = entry.getValue();
            int size = temp.size();
            // Largest class needs no oversampling; an empty class cannot be oversampled
            // (Random.nextInt(0) throws IllegalArgumentException)
            if (size == maxSize || size == 0)
                continue;
            // Ensure a copy is made, so the original list is not modified
            List<PathObject> list = new ArrayList<>(maxSize);
            list.addAll(temp);
            for (int i = 0; i < maxSize - size; i++)
                list.add(temp.get(random.nextInt(size)));
            entry.setValue(list);
        }
    }

    BackgroundClassificationTask task = new BackgroundClassificationTask(hierarchy, features, mapTraining, mapTest, testOnTrainingData);
    qupath.submitShortTask(task);
}
Use of qupath.lib.objects.classes.PathClass in project qupath (by qupath).
Class ClassifierBuilderPane, method doClassification:
/**
 * Trains the current classifier, then applies it to the hierarchy's detections and, when several
 * images are retained, to the test objects. Results are passed on to {@code completeClassification}.
 * <p>
 * Off the JavaFX application thread, classifications are applied to
 * {@code PathObjectClassificationProxy} wrappers rather than directly to the hierarchy objects.
 *
 * @param hierarchy the hierarchy whose detection objects should be classified
 * @param features the names of the features used for training
 * @param mapTraining training objects grouped by class
 * @param mapTest test objects grouped by class
 * @param testOnTrainingData true if no separate test set could be created, so the test map is the training data
 */
private void doClassification(final PathObjectHierarchy hierarchy, final List<String> features, final Map<PathClass, List<PathObject>> mapTraining, final Map<PathClass, List<PathObject>> mapTest, final boolean testOnTrainingData) {
    if (!Platform.isFxApplicationThread()) {
        Platform.runLater(() -> {
            progressIndicator.setProgress(-1);
            progressIndicator.setPrefSize(30, 30);
            progressIndicator.setVisible(true);
        });
    }
    long startTime = System.currentTimeMillis();

    // Train classifier with requested normalization
    Normalization normalization = (Normalization) paramsUpdate.getChoiceParameterValue("normalizationMethod");
    String errorMessage = null;
    boolean classifierChanged = classifier != lastClassifierCompleted;
    try {
        classifierChanged = classifier.updateClassifier(mapTraining, features, normalization) || classifierChanged;
    } catch (Exception e) {
        errorMessage = "Classifier training failed with message:\n" + e.getLocalizedMessage() + "\nPlease try again with different settings.";
        logger.error("Classifier training failed", e);
    }
    if (classifier == null || !classifier.isValid()) {
        updateClassifierSummary(errorMessage);
        logger.error("Classifier is invalid!");
        updatingClassification = false;
        btnSaveClassifier.setDisable(classifier == null || !classifier.isValid());
        return;
    }
    long middleTime = System.currentTimeMillis();
    logger.info(String.format("Classifier training time: %.2f seconds", (middleTime - startTime) / 1000.));

    // Create an intensity classifier, if required
    PathObjectClassifier intensityClassifier = panelIntensities.getIntensityClassifier();

    // Apply classifier to everything
    Collection<PathObject> pathObjectsOrig = hierarchy.getDetectionObjects();
    int nClassified = 0;
    // Possibly get proxy objects, depending on the thread we're on.
    // NOTE(review): on the FX thread pathObjects is the same collection as pathObjectsOrig, so the
    // filtering/additions below also change what completeClassification receives as the originals - confirm intended.
    Collection<PathObject> pathObjects;
    if (Platform.isFxApplicationThread())
        pathObjects = pathObjectsOrig;
    else
        pathObjects = pathObjectsOrig.stream().map(p -> new PathObjectClassificationProxy(p)).collect(Collectors.toList());

    // Omit any objects that have already been classified by anything other than any of the target classes
    if (paramsUpdate.getBooleanParameterValue("limitTrainingToRepresentedClasses")) {
        Set<PathClass> representedClasses = mapTraining.keySet();
        pathObjects.removeIf(pathObject -> {
            PathClass currentClass = pathObject.getPathClass();
            return currentClass != null && !representedClasses.contains(currentClass.getBaseClass());
        });
    }

    // In the event that we're using retained images, ensure we classify everything in our test set
    if (retainedObjectsMap.size() > 1 && !mapTest.isEmpty()) {
        for (List<PathObject> testObjects : mapTest.values())
            pathObjects.addAll(testObjects);
    }

    if (classifierChanged || hierarchyChanged)
        nClassified = classifier.classifyPathObjects(pathObjects);
    else
        logger.info("Main classifier unchanged...");
    if (intensityClassifier != null)
        intensityClassifier.classifyPathObjects(pathObjects);

    // Only a failure if classification was actually attempted
    if (nClassified <= 0 && (classifierChanged || hierarchyChanged))
        logger.error("Classification failed - no objects classified!");

    long endTime = System.currentTimeMillis();
    logger.info(String.format("Classification time: %.2f seconds", (endTime - middleTime) / 1000.));
    completeClassification(hierarchy, pathObjects, pathObjectsOrig, mapTest, testOnTrainingData);
}
Use of qupath.lib.objects.classes.PathClass in project qupath (by qupath).
Class CreateTrainingImageCommand, method promptToCreateTrainingImage:
/**
 * Prompt to create a training image, based upon annotations throughout a project.
 * <p>
 * The chosen classification, maximum width, z-stack and rectangles-only options are stored in the
 * corresponding static fields, so they are offered again as defaults on the next invocation.
 *
 * @param project the project whose images should be searched for annotated regions; if null, an error is shown
 * @param availableClasses the classifications the user may choose between; if null or empty, an error is shown
 * @return the entry of the new training image, created within the project, or null if the command
 *         was cancelled, no suitable regions were found, or an error occurred
 */
public static ProjectImageEntry<BufferedImage> promptToCreateTrainingImage(Project<BufferedImage> project, List<PathClass> availableClasses) {
    if (project == null) {
        Dialogs.showErrorMessage(NAME, "You need a project!");
        return null;
    }
    if (availableClasses == null || availableClasses.isEmpty()) {
        Dialogs.showErrorMessage(NAME, "Please ensure classifications are available in QuPath!");
        return null;
    }
    List<PathClass> pathClasses = new ArrayList<>(availableClasses);
    // Fall back to the first available class if the previously-used one is no longer available
    if (!pathClasses.contains(pathClass))
        pathClass = pathClasses.get(0);
    var params = new ParameterList()
            .addEmptyParameter("Generates a single image from regions extracted from the project.")
            .addEmptyParameter("Before running this command, add classified rectangle annotations to select the regions.")
            .addChoiceParameter("pathClass", "Classification", pathClass, pathClasses, "Select classification for annotated regions")
            .addIntParameter("maxWidth", "Preferred image width", maxWidth, "px", "Preferred maximum width of the training image, in pixels")
            .addBooleanParameter("doZ", "Do z-stacks", doZ, "Take all slices of a z-stack, where possible")
            .addBooleanParameter("rectanglesOnly", "Rectangles only", rectanglesOnly, "Only extract regions annotated with rectangles. Otherwise, the bounding box of all regions with the classification will be taken.")
            .addEmptyParameter("Note this command requires images to have similar bit-depths/channels/pixel sizes for compatibility.");
    if (!Dialogs.showParameterDialog(NAME, params))
        return null;
    pathClass = (PathClass) params.getChoiceParameterValue("pathClass");
    maxWidth = params.getIntParameterValue("maxWidth");
    doZ = params.getBooleanParameterValue("doZ");
    rectanglesOnly = params.getBooleanParameterValue("rectanglesOnly");

    // Snapshot the settings so the background task is unaffected by later changes to the static fields
    final PathClass selectedClass = pathClass;
    final int selectedMaxWidth = maxWidth;
    final boolean selectedDoZ = doZ;
    final boolean selectedRectanglesOnly = rectanglesOnly;
    var task = new Task<SparseImageServer>() {
        @Override
        protected SparseImageServer call() throws Exception {
            return createSparseServer(project, selectedClass, selectedMaxWidth, selectedDoZ, selectedRectanglesOnly);
        }
    };
    var dialog = new ProgressDialog(task);
    dialog.setTitle(NAME);
    dialog.setHeaderText("Creating training image...");
    // Shut the executor down once the task is submitted; the task still runs to completion, but the
    // (non-daemon) worker thread is released afterwards instead of lingering indefinitely
    var executor = Executors.newSingleThreadExecutor();
    try {
        executor.submit(task);
    } finally {
        executor.shutdown();
    }
    dialog.showAndWait();
    try {
        var server = task.get();
        if (server == null) {
            Dialogs.showErrorMessage("Sparse image server", "No suitable annotations found in the current project!");
            return null;
        }
        ProjectImageEntry<BufferedImage> entry;
        try {
            if (server.getManager().getRegions().isEmpty()) {
                Dialogs.showErrorMessage("Sparse image server", "No suitable annotations found in the current project!");
                return null;
            }
            entry = ProjectCommands.addSingleImageToProject(project, server, null);
        } finally {
            // Always release the server, even if no regions were found or adding it to the project failed
            server.close();
        }
        project.syncChanges();
        return entry;
    } catch (InterruptedException e) {
        // Preserve the interrupt status for callers
        Thread.currentThread().interrupt();
        Dialogs.showErrorMessage("Sparse image server", e);
        return null;
    } catch (Exception e) {
        Dialogs.showErrorMessage("Sparse image server", e);
        return null;
    }
}
Use of qupath.lib.objects.classes.PathClass in project qupath (by qupath).
Class PixelClassifierPane, method doClassification:
/**
 * Trains the selected pixel classifier model on the current training images and logs a rough
 * accuracy estimate (and, for random trees, variable importance). It then installs a live
 * classification overlay that uses the trained model.
 */
private void doClassification() {
    var imageData = qupath.getImageData();
    if (imageData == null && qupath.getViewers().stream().noneMatch(v -> v.getImageData() != null)) {
        logger.debug("doClassification() called, but no images are open");
        return;
    }

    var model = selectedClassifier.get();
    if (model == null) {
        Dialogs.showErrorNotification("Pixel classifier", "No classifier selected!");
        return;
    }

    ClassifierTrainingData trainingData;
    try {
        var trainingImages = getTrainingImageData();
        if (trainingImages.size() > 1)
            logger.info("Creating training data from {} images", trainingImages.size());
        trainingData = helper.createTrainingData(trainingImages);
    } catch (Exception e) {
        logger.error("Error when updating training data", e);
        return;
    }
    if (trainingData == null) {
        resetPieChart();
        return;
    }

    // TODO: Optionally limit the number of training samples we use
    // Ensure we seed the RNG for reproducibility
    opencv_core.setRNGSeed(rngSeed);

    // TODO: Prevent training K nearest neighbor with a huge number of samples (very slow!)
    var actualMaxSamples = this.maxSamples;
    var trainData = trainingData.getTrainData();
    if (actualMaxSamples > 0 && trainData.getNTrainSamples() > actualMaxSamples)
        trainData.setTrainTestSplit(actualMaxSamples, true);
    else
        trainData.shuffleTrainTest();

    // Apply normalization, if we need to
    FeaturePreprocessor preprocessor = normalization.build(trainData.getTrainSamples(), false);
    if (preprocessor.doesSomething())
        preprocessingOp = ImageOps.ML.preprocessor(preprocessor);
    else
        preprocessingOp = null;

    var labels = trainingData.getLabelMap();

    // Using getTrainNormCatResponses() causes confusion if classes are not represented
    var targets = trainData.getTrainResponses();
    IntBuffer buffer = targets.createBuffer();
    int n = (int) targets.total();
    // Count training samples per label (labels are used directly as array indices)
    var rawCounts = new int[labels.size()];
    for (int i = 0; i < n; i++)
        rawCounts[buffer.get(i)] += 1;
    Map<PathClass, Integer> counts = new LinkedHashMap<>();
    for (var entry : labels.entrySet())
        counts.put(entry.getKey(), rawCounts[entry.getValue()]);
    updatePieChart(counts);

    Mat weights = null;
    if (reweightSamples) {
        // Weight each sample by n / (count of its class), so each represented class has the same total weight
        weights = new Mat(n, 1, opencv_core.CV_32FC1);
        FloatIndexer bufferWeights = weights.createIndexer();
        float[] weightArray = new float[rawCounts.length];
        for (int i = 0; i < weightArray.length; i++) {
            int c = rawCounts[i];
            weightArray[i] = c == 0 ? 1 : (float) n / c;
        }
        for (int i = 0; i < n; i++) {
            int label = buffer.get(i);
            bufferWeights.put(i, weightArray[label]);
        }
        bufferWeights.release();
    }

    // Create TrainData in an appropriate format (e.g. labels or one-hot encoding)
    var trainSamples = trainData.getTrainSamples();
    var trainResponses = trainData.getTrainResponses();
    preprocessor.apply(trainSamples, false);
    trainData = model.createTrainData(trainSamples, trainResponses, weights, false);
    logger.info("Training data: {} x {}, Target data: {} x {}", trainSamples.rows(), trainSamples.cols(), trainResponses.rows(), trainResponses.cols());

    model.train(trainData);

    // Calculate accuracy using whatever we can, as a rough guide to progress.
    // NOTE(review): 'trainData' now refers to data created by model.createTrainData() from the already
    // preprocessed training samples. If it has a non-empty test set, re-applying the preprocessor below
    // may normalize twice, and getTestNormCatResponses() yields normalized category indices rather than
    // raw labels. Confirm before relying on the held-out figure.
    var test = trainData.getTestSamples();
    String testSet = "HELD-OUT TRAINING SET";
    if (test.empty()) {
        test = trainSamples;
        testSet = "TRAINING SET";
    } else {
        preprocessor.apply(test, false);
        buffer = trainData.getTestNormCatResponses().createBuffer();
    }
    var testResults = new Mat();
    model.predict(test, testResults, null);
    IntBuffer bufferResults = testResults.createBuffer();
    int nTest = (int) testResults.rows();
    int nCorrect = 0;
    for (int i = 0; i < nTest; i++) {
        if (bufferResults.get(i) == buffer.get(i))
            nCorrect++;
    }
    // Accuracy is relative to the number of samples actually tested, not the training sample count
    if (nTest > 0)
        logger.info("Current accuracy on the {}: {} %", testSet, GeneralTools.formatNumber(nCorrect * 100.0 / nTest, 1));
    else
        logger.warn("No samples available to estimate accuracy on the {}", testSet);

    if (model instanceof RTreesClassifier) {
        var trees = (RTreesClassifier) model;
        if (trees.hasFeatureImportance() && imageData != null)
            logVariableImportance(trees, helper.getFeatureOp().getChannels(imageData).stream().map(c -> c.getName()).collect(Collectors.toList()));
    }

    trainData.close();

    var featureCalculator = helper.getFeatureOp();
    if (preprocessingOp != null)
        featureCalculator = featureCalculator.appendOps(preprocessingOp);

    // TODO: CHECK IF INPUT SIZE SHOULD BE DEFINED
    int inputWidth = 512;
    int inputHeight = 512;
    var cal = helper.getResolution();
    var channelType = ImageServerMetadata.ChannelType.CLASSIFICATION;
    if (model.supportsProbabilities())
        channelType = selectedOutputType.get();

    // Channels are needed for probability output (and work for classification as well)
    var labels2 = new TreeMap<Integer, PathClass>();
    for (var entry : labels.entrySet()) {
        var previous = labels2.put(entry.getValue(), entry.getKey());
        if (previous != null)
            logger.warn("Duplicate label found! {} matches with {} and {}, only the latter be used", entry.getValue(), previous, entry.getKey());
    }
    var channels = PathClassifierTools.classificationLabelsToChannels(labels2, true);

    PixelClassifierMetadata metadata = new PixelClassifierMetadata.Builder()
            .inputResolution(cal)
            .inputShape(inputWidth, inputHeight)
            .setChannelType(channelType)
            .outputChannels(channels)
            .build();

    currentClassifier.set(PixelClassifiers.createClassifier(model, featureCalculator, metadata, true));

    var overlay = PixelClassificationOverlay.create(qupath.getOverlayOptions(), currentClassifier.get(), getLivePredictionThreads());
    replaceOverlay(overlay);
}
Aggregations