use of qupath.lib.classifiers.Normalization in project qupath by qupath.
the class OpenCvClassifier method readExternal.
/**
 * Restores the state of this classifier from an {@link ObjectInput} stream.
 * <p>
 * The stream is read in this order: a {@code long} format version (1 or 2), a {@code long}
 * timestamp, then the serialized classification list, normalization scale and offset arrays,
 * measurement names, training array and response array. Version 2 streams carry one further
 * object: the name of the normalization method.
 * <p>
 * If both the training and response arrays are present, the classifier is retrained
 * before this method returns.
 *
 * @param in the stream to read from
 * @throws IOException if the version number is not 1 or 2, or if reading fails
 * @throws ClassNotFoundException if the class of a serialized object cannot be found
 */
@SuppressWarnings("unchecked")
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    final long formatVersion = in.readLong();
    if (!(formatVersion >= 1 && formatVersion <= 2)) {
        throw new IOException("Unsupported version!");
    }

    timestamp = in.readLong();

    pathClasses = (List<PathClass>) in.readObject();
    if (pathClasses != null) {
        // Swap every deserialized entry for the instance returned by the factory,
        // so that the list holds the correct, single entries
        pathClasses.replaceAll(pathClass -> PathClassFactory.getSingletonPathClass(pathClass));
    }

    normScale = (double[]) in.readObject();
    normOffset = (double[]) in.readObject();
    measurements = (List<String>) in.readObject();
    arrayTraining = (float[]) in.readObject();
    arrayResponses = (int[]) in.readObject();

    if (formatVersion == 2) {
        // Match on toString() rather than using valueOf();
        // if nothing matches, the current normalization is left untouched
        final String methodName = (String) in.readObject();
        for (Normalization candidate : Normalization.values()) {
            if (candidate.toString().equals(methodName)) {
                normalization = candidate;
                break;
            }
        }
    }

    // Nothing to train with unless both arrays were stored
    if (arrayTraining == null || arrayResponses == null) {
        return;
    }
    createAndTrainClassifier();
}
use of qupath.lib.classifiers.Normalization in project qupath by qupath.
the class ClassifierBuilderPane method crossValidateAcrossImages.
/**
 * Runs a leave-one-image-out cross validation over the retained training objects.
 * <p>
 * For every image key in the retained objects map, the objects of that image form the
 * validation set while the objects of all other images are merged (per class) into the
 * training set. The training set is resampled according to the current split settings,
 * the currently selected classifier is retrained, and the percentage of validation objects
 * whose resulting classification equals their expected class is logged.
 * <p>
 * Because the selected ('live') classifier itself is retrained for each fold, a normal
 * classification update is requested at the end.
 */
private void crossValidateAcrossImages() {
    // Try to put the current image data information into the map, which stores training data separated by image path
    updateRetainedObjectsMap();
    Map<String, Map<PathClass, List<PathObject>>> tempMap = new LinkedHashMap<>(retainedObjectsMap.getMap());

    // These settings are the same for every fold, so read them once up front
    Normalization normalization = (Normalization) paramsUpdate.getChoiceParameterValue("normalizationMethod");
    SplitType splitType = (SplitType) paramsUpdate.getChoiceParameterValue("splitType");
    double maxTrainingProportion = paramsUpdate.getIntParameterValue("maxTrainingPercent") / 100.;
    long seed = paramsUpdate.getIntParameterValue("randomSeed");

    for (Entry<String, Map<PathClass, List<PathObject>>> validationEntry : tempMap.entrySet()) {
        String key = validationEntry.getKey();
        Map<PathClass, List<PathObject>> validationMap = validationEntry.getValue();

        // Merge the objects of every other image into a single training map
        Map<PathClass, List<PathObject>> trainingMap = new LinkedHashMap<>();
        for (Entry<String, Map<PathClass, List<PathObject>>> entry : tempMap.entrySet()) {
            if (entry.getKey().equals(key)) {
                continue;
            }
            for (Entry<PathClass, List<PathObject>> entry2 : entry.getValue().entrySet()) {
                // Always accumulate into a fresh list, so the retained lists are never modified
                trainingMap.computeIfAbsent(entry2.getKey(), k -> new ArrayList<>()).addAll(entry2.getValue());
            }
        }

        // Perform subsampling
        trainingMap = PathClassificationLabellingHelper.resampleClassificationMap(trainingMap, splitType, maxTrainingProportion, seed);

        // Get the current classifier - unfortunately, there's no easy way to duplicate/create a new one,
        // so we are left working with the 'live' classifier
        PathObjectClassifier classifier = (T) comboClassifiers.getSelectionModel().getSelectedItem();
        classifier.updateClassifier(trainingMap, featurePanel.getSelectedFeatures(), normalization);

        int nCorrect = 0;
        int nTested = 0;
        for (Entry<PathClass, List<PathObject>> entryValidation : validationMap.entrySet()) {
            classifier.classifyPathObjects(entryValidation.getValue());
            for (PathObject temp : entryValidation.getValue()) {
                if (entryValidation.getKey().equals(temp.getPathClass())) {
                    nCorrect++;
                }
                nTested++;
            }
        }

        // Note: evaluates to NaN if the validation image contributed no objects
        double percent = nCorrect * 100.0 / nTested;
        logger.info(String.format("Percentage correct for %s: %.2f%%", key, percent));
        System.err.println(String.format("Percentage correct for %s: %.2f%% (%d/%d)", key, percent, nCorrect, nTested));
    }

    // Force a normal classifier update, to compensate for the fact we had to modify the 'live' classifier
    updateClassification(false);
}
use of qupath.lib.classifiers.Normalization in project qupath by qupath.
the class ClassifierBuilderPane method doClassification.
/**
 * Trains the current classifier and applies it to the detection objects of a hierarchy.
 * <p>
 * The classifier is updated with the training map, the selected features and the requested
 * normalization. If training fails or leaves the classifier invalid, the classifier summary is
 * updated with the error message, saving is disabled and the method returns early. Otherwise
 * the detections (optionally filtered to the represented classes, and extended with the test
 * objects when more than one image is retained) are classified, the optional intensity
 * classifier is applied, and the results are handed to {@code completeClassification}.
 *
 * @param hierarchy the hierarchy whose detection objects should be classified
 * @param features the names of the features used for training
 * @param mapTraining training objects, grouped by classification
 * @param mapTest test objects, grouped by classification
 * @param testOnTrainingData passed through unchanged to {@code completeClassification}
 */
private void doClassification(final PathObjectHierarchy hierarchy, final List<String> features, final Map<PathClass, List<PathObject>> mapTraining, final Map<PathClass, List<PathObject>> mapTest, final boolean testOnTrainingData) {
    if (!Platform.isFxApplicationThread()) {
        // UI controls must be modified on the JavaFX application thread
        Platform.runLater(() -> {
            progressIndicator.setProgress(-1);
            progressIndicator.setPrefSize(30, 30);
            progressIndicator.setVisible(true);
        });
    }

    long startTime = System.currentTimeMillis();

    // Train classifier with requested normalization
    Normalization normalization = (Normalization) paramsUpdate.getChoiceParameterValue("normalizationMethod");
    String errorMessage = null;
    boolean classifierChanged = classifier != lastClassifierCompleted;
    try {
        classifierChanged = classifier.updateClassifier(mapTraining, features, normalization) || classifierChanged;
    } catch (Exception e) {
        errorMessage = "Classifier training failed with message:\n" + e.getLocalizedMessage() + "\nPlease try again with different settings.";
        // Send the stack trace to the log rather than straight to stderr
        logger.error("Classifier training failed", e);
    }

    if (classifier == null || !classifier.isValid()) {
        updateClassifierSummary(errorMessage);
        logger.error("Classifier is invalid!");
        updatingClassification = false;
        // There is no valid classifier at this point, so there is nothing to save
        btnSaveClassifier.setDisable(true);
        return;
    }

    long middleTime = System.currentTimeMillis();
    logger.info(String.format("Classifier training time: %.2f seconds", (middleTime - startTime) / 1000.));

    // Create an intensity classifier, if required
    PathObjectClassifier intensityClassifier = panelIntensities.getIntensityClassifier();

    // Apply classifier to everything
    Collection<PathObject> pathObjectsOrig = hierarchy.getDetectionObjects();
    int nClassified = 0;

    // Possibly get proxy objects, depending on the thread we're on
    Collection<PathObject> pathObjects;
    if (Platform.isFxApplicationThread()) {
        pathObjects = pathObjectsOrig;
    } else {
        pathObjects = pathObjectsOrig.stream().map(p -> new PathObjectClassificationProxy(p)).collect(Collectors.toList());
    }

    // Omit any objects that have already been classified by anything other than any of the target classes
    if (paramsUpdate.getBooleanParameterValue("limitTrainingToRepresentedClasses")) {
        Iterator<PathObject> iterator = pathObjects.iterator();
        Set<PathClass> representedClasses = mapTraining.keySet();
        while (iterator.hasNext()) {
            PathClass currentClass = iterator.next().getPathClass();
            if (currentClass != null && !representedClasses.contains(currentClass.getBaseClass())) {
                iterator.remove();
            }
        }
    }

    // In the event that we're using retained images, ensure we classify everything in our test set
    if (retainedObjectsMap.size() > 1 && !mapTest.isEmpty()) {
        for (Entry<PathClass, List<PathObject>> entry : mapTest.entrySet()) {
            pathObjects.addAll(entry.getValue());
        }
    }

    boolean classificationRequired = classifierChanged || hierarchyChanged;
    if (classificationRequired) {
        nClassified = classifier.classifyPathObjects(pathObjects);
    } else {
        logger.info("Main classifier unchanged...");
    }

    if (intensityClassifier != null) {
        intensityClassifier.classifyPathObjects(pathObjects);
    }

    // Only report a failure if classification was actually attempted
    if (nClassified <= 0 && classificationRequired) {
        logger.error("Classification failed - no objects classified!");
    }

    long endTime = System.currentTimeMillis();
    logger.info(String.format("Classification time: %.2f seconds", (endTime - middleTime) / 1000.));

    completeClassification(hierarchy, pathObjects, pathObjectsOrig, mapTest, testOnTrainingData);
}
use of qupath.lib.classifiers.Normalization in project qupath by qupath.
the class PixelClassifierPane method doClassification.
/**
 * Trains the selected pixel classification model on the current training data and installs
 * the resulting classifier as a live prediction overlay.
 * <p>
 * The main steps are:
 * <ol>
 *   <li>Return early if no image is open in any viewer, no model is selected, or no training
 *       data can be created.</li>
 *   <li>Seed the OpenCV random number generator, then either limit the number of training
 *       samples or shuffle them.</li>
 *   <li>Build the feature preprocessor from the training samples and apply it to them.</li>
 *   <li>Count the samples per label, update the pie chart and, if requested, create
 *       per-sample weights of {@code n / count} for the sample's label.</li>
 *   <li>Train the model and log a rough accuracy estimate.</li>
 *   <li>Create the classifier with its metadata and replace the current overlay.</li>
 * </ol>
 */
private void doClassification() {
    var imageData = qupath.getImageData();
    if (imageData == null) {
        // The main image may be missing while another viewer still has one open
        if (qupath.getViewers().stream().noneMatch(v -> v.getImageData() != null)) {
            logger.debug("doClassification() called, but no images are open");
            return;
        }
    }

    var model = selectedClassifier.get();
    if (model == null) {
        Dialogs.showErrorNotification("Pixel classifier", "No classifier selected!");
        return;
    }

    ClassifierTrainingData trainingData;
    try {
        var trainingImages = getTrainingImageData();
        if (trainingImages.size() > 1) {
            logger.info("Creating training data from {} images", trainingImages.size());
        }
        trainingData = helper.createTrainingData(trainingImages);
    } catch (Exception e) {
        logger.error("Error when updating training data", e);
        return;
    }
    if (trainingData == null) {
        resetPieChart();
        return;
    }

    // TODO: Optionally limit the number of training samples we use

    // Ensure we seed the RNG for reproducibility
    opencv_core.setRNGSeed(rngSeed);

    // TODO: Prevent training K nearest neighbor with a huge number of samples (very slow!)
    var actualMaxSamples = this.maxSamples;
    var trainData = trainingData.getTrainData();
    if (actualMaxSamples > 0 && trainData.getNTrainSamples() > actualMaxSamples) {
        trainData.setTrainTestSplit(actualMaxSamples, true);
    } else {
        trainData.shuffleTrainTest();
    }

    // Apply normalization, if we need to
    FeaturePreprocessor preprocessor = normalization.build(trainData.getTrainSamples(), false);
    if (preprocessor.doesSomething()) {
        preprocessingOp = ImageOps.ML.preprocessor(preprocessor);
    } else {
        preprocessingOp = null;
    }

    var labels = trainingData.getLabelMap();

    // Using getTrainNormCatResponses() causes confusion if classes are not represented
    var targets = trainData.getTrainResponses();
    IntBuffer buffer = targets.createBuffer();
    int n = (int) targets.total();

    // Count how many training samples carry each label
    var rawCounts = new int[labels.size()];
    for (int i = 0; i < n; i++) {
        rawCounts[buffer.get(i)] += 1;
    }
    Map<PathClass, Integer> counts = new LinkedHashMap<>();
    for (var entry : labels.entrySet()) {
        counts.put(entry.getKey(), rawCounts[entry.getValue()]);
    }
    updatePieChart(counts);

    Mat weights = null;
    if (reweightSamples) {
        weights = new Mat(n, 1, opencv_core.CV_32FC1);
        FloatIndexer bufferWeights = weights.createIndexer();
        // Weight per label is n / count, so labels with fewer samples get larger weights
        float[] weightArray = new float[rawCounts.length];
        for (int i = 0; i < weightArray.length; i++) {
            int c = rawCounts[i];
            weightArray[i] = c == 0 ? 1 : (float) n / c;
        }
        for (int i = 0; i < n; i++) {
            int label = buffer.get(i);
            bufferWeights.put(i, weightArray[label]);
        }
        bufferWeights.release();
    }

    // Create TrainData in an appropriate format (e.g. labels or one-hot encoding)
    var trainSamples = trainData.getTrainSamples();
    var trainResponses = trainData.getTrainResponses();
    preprocessor.apply(trainSamples, false);
    trainData = model.createTrainData(trainSamples, trainResponses, weights, false);
    logger.info("Training data: {} x {}, Target data: {} x {}", trainSamples.rows(), trainSamples.cols(), trainResponses.rows(), trainResponses.cols());
    model.train(trainData);

    // Calculate accuracy using whatever we can, as a rough guide to progress
    var test = trainData.getTestSamples();
    String testSet = "HELD-OUT TRAINING SET";
    if (test.empty()) {
        // No held-out samples available, so fall back to the training samples
        // (the buffer still refers to the training responses)
        test = trainSamples;
        testSet = "TRAINING SET";
    } else {
        preprocessor.apply(test, false);
        buffer = trainData.getTestNormCatResponses().createBuffer();
    }
    var testResults = new Mat();
    model.predict(test, testResults, null);
    IntBuffer bufferResults = testResults.createBuffer();
    int nTest = (int) testResults.rows();
    int nCorrect = 0;
    for (int i = 0; i < nTest; i++) {
        if (bufferResults.get(i) == buffer.get(i)) {
            nCorrect++;
        }
    }
    // The percentage must be relative to the number of samples actually tested
    logger.info("Current accuracy on the {}: {} %", testSet, GeneralTools.formatNumber(nCorrect * 100.0 / nTest, 1));

    if (model instanceof RTreesClassifier) {
        var trees = (RTreesClassifier) model;
        if (trees.hasFeatureImportance() && imageData != null) {
            logVariableImportance(trees, helper.getFeatureOp().getChannels(imageData).stream().map(c -> c.getName()).collect(Collectors.toList()));
        }
    }

    trainData.close();

    var featureCalculator = helper.getFeatureOp();
    if (preprocessingOp != null) {
        featureCalculator = featureCalculator.appendOps(preprocessingOp);
    }

    // TODO: CHECK IF INPUT SIZE SHOULD BE DEFINED
    int inputWidth = 512;
    int inputHeight = 512;

    var cal = helper.getResolution();
    var channelType = ImageServerMetadata.ChannelType.CLASSIFICATION;
    if (model.supportsProbabilities()) {
        channelType = selectedOutputType.get();
    }

    // Channels are needed for probability output (and work for classification as well)
    var labels2 = new TreeMap<Integer, PathClass>();
    for (var entry : labels.entrySet()) {
        var previous = labels2.put(entry.getValue(), entry.getKey());
        if (previous != null) {
            logger.warn("Duplicate label found! {} matches with {} and {}, only the latter be used", entry.getValue(), previous, entry.getKey());
        }
    }
    var channels = PathClassifierTools.classificationLabelsToChannels(labels2, true);

    PixelClassifierMetadata metadata = new PixelClassifierMetadata.Builder()
            .inputResolution(cal)
            .inputShape(inputWidth, inputHeight)
            .setChannelType(channelType)
            .outputChannels(channels)
            .build();

    currentClassifier.set(PixelClassifiers.createClassifier(model, featureCalculator, metadata, true));

    var overlay = PixelClassificationOverlay.create(qupath.getOverlayOptions(), currentClassifier.get(), getLivePredictionThreads());
    replaceOverlay(overlay);
}
Aggregations