Use of qupath.process.gui.commands.ml.PixelClassifierTraining.ClassifierTrainingData in the project qupath by qupath.
In class PixelClassifierPane, method doClassification.
/**
 * Trains the currently selected classifier using training data created from
 * {@code getTrainingImageData()}, logs a rough accuracy estimate, updates the class-count pie chart,
 * and replaces the current overlay with a {@code PixelClassificationOverlay} for the new classifier.
 * <p>
 * Returns early, without changing the current classifier, if no image is open, no classifier is
 * selected, or the training data could not be created.
 */
private void doClassification() {
	var imageData = qupath.getImageData();
	if (imageData == null && qupath.getViewers().stream().noneMatch(v -> v.getImageData() != null)) {
		logger.debug("doClassification() called, but no images are open");
		return;
	}

	var model = selectedClassifier.get();
	if (model == null) {
		Dialogs.showErrorNotification("Pixel classifier", "No classifier selected!");
		return;
	}

	ClassifierTrainingData trainingData;
	try {
		var trainingImages = getTrainingImageData();
		if (trainingImages.size() > 1)
			logger.info("Creating training data from {} images", trainingImages.size());
		trainingData = helper.createTrainingData(trainingImages);
	} catch (Exception e) {
		logger.error("Error when updating training data", e);
		return;
	}
	if (trainingData == null) {
		resetPieChart();
		return;
	}

	// TODO: Optionally limit the number of training samples we use
	// Seed OpenCV's RNG so that shuffling / train-test splitting is reproducible
	opencv_core.setRNGSeed(rngSeed);

	// TODO: Prevent training K nearest neighbor with a huge number of samples (very slow!)
	var actualMaxSamples = this.maxSamples;
	var trainData = trainingData.getTrainData();
	if (actualMaxSamples > 0 && trainData.getNTrainSamples() > actualMaxSamples)
		trainData.setTrainTestSplit(actualMaxSamples, true);
	else
		trainData.shuffleTrainTest();

	// Apply normalization, if we need to
	FeaturePreprocessor preprocessor = normalization.build(trainData.getTrainSamples(), false);
	preprocessingOp = preprocessor.doesSomething() ? ImageOps.ML.preprocessor(preprocessor) : null;

	var labels = trainingData.getLabelMap();

	// Use the raw responses: getTrainNormCatResponses() causes confusion if classes are not represented
	var targets = trainData.getTrainResponses();
	IntBuffer buffer = targets.createBuffer();
	int n = (int) targets.total();
	// Sized by the largest label value (not the number of labels), so non-contiguous labels are safe
	int[] rawCounts = countLabels(buffer, n, labels.values());

	Map<PathClass, Integer> counts = new LinkedHashMap<>();
	for (var entry : labels.entrySet()) {
		counts.put(entry.getKey(), rawCounts[entry.getValue()]);
	}
	updatePieChart(counts);

	Mat weights = reweightSamples ? createSampleWeights(buffer, n, rawCounts) : null;

	// Create TrainData in an appropriate format (e.g. labels or one-hot encoding)
	var trainSamples = trainData.getTrainSamples();
	var trainResponses = trainData.getTrainResponses();
	preprocessor.apply(trainSamples, false);
	trainData = model.createTrainData(trainSamples, trainResponses, weights, false);
	logger.info("Training data: {} x {}, Target data: {} x {}", trainSamples.rows(), trainSamples.cols(), trainResponses.rows(), trainResponses.cols());
	model.train(trainData);

	// Calculate accuracy using whatever we can, as a rough guide to progress
	var test = trainData.getTestSamples();
	String testSet = "HELD-OUT TRAINING SET";
	IntBuffer expected = buffer;
	if (test.empty()) {
		test = trainSamples;
		testSet = "TRAINING SET";
	} else {
		// NOTE(review): test comes from the TrainData built from already-preprocessed samples;
		// confirm whether applying the preprocessor again here is intended.
		preprocessor.apply(test, false);
		// Compare against raw responses, since the model was trained on raw (not normalized categorical) labels
		expected = trainData.getTestResponses().createBuffer();
	}
	try (var testResults = new Mat()) {
		model.predict(test, testResults, null);
		IntBuffer bufferResults = testResults.createBuffer();
		int nTest = testResults.rows();
		int nCorrect = 0;
		for (int i = 0; i < nTest; i++) {
			if (bufferResults.get(i) == expected.get(i))
				nCorrect++;
		}
		// Accuracy is relative to the number of samples actually evaluated
		if (nTest > 0)
			logger.info("Current accuracy on the {}: {} %", testSet, GeneralTools.formatNumber(nCorrect * 100.0 / nTest, 1));
	}

	if (model instanceof RTreesClassifier) {
		var trees = (RTreesClassifier) model;
		if (trees.hasFeatureImportance() && imageData != null)
			logVariableImportance(trees, helper.getFeatureOp().getChannels(imageData).stream().map(c -> c.getName()).collect(Collectors.toList()));
	}
	trainData.close();
	// The weights were only needed for training; release the native memory now
	if (weights != null)
		weights.close();

	var featureCalculator = helper.getFeatureOp();
	if (preprocessingOp != null)
		featureCalculator = featureCalculator.appendOps(preprocessingOp);

	// TODO: CHECK IF INPUT SIZE SHOULD BE DEFINED
	int inputWidth = 512;
	int inputHeight = 512;

	var cal = helper.getResolution();
	var channelType = ImageServerMetadata.ChannelType.CLASSIFICATION;
	if (model.supportsProbabilities()) {
		channelType = selectedOutputType.get();
	}

	// Channels are needed for probability output (and work for classification as well)
	var labels2 = new TreeMap<Integer, PathClass>();
	for (var entry : labels.entrySet()) {
		var previous = labels2.put(entry.getValue(), entry.getKey());
		if (previous != null)
			logger.warn("Duplicate label found! {} matches with {} and {}, only the latter be used", entry.getValue(), previous, entry.getKey());
	}
	var channels = PathClassifierTools.classificationLabelsToChannels(labels2, true);

	PixelClassifierMetadata metadata = new PixelClassifierMetadata.Builder()
			.inputResolution(cal)
			.inputShape(inputWidth, inputHeight)
			.setChannelType(channelType)
			.outputChannels(channels)
			.build();

	currentClassifier.set(PixelClassifiers.createClassifier(model, featureCalculator, metadata, true));
	var overlay = PixelClassificationOverlay.create(qupath.getOverlayOptions(), currentClassifier.get(), getLivePredictionThreads());
	replaceOverlay(overlay);
}

/**
 * Counts how often each integer label occurs in the first {@code n} entries of {@code targets}.
 * The returned array can be indexed by any value in {@code labelValues} or in {@code targets},
 * so label values do not need to form a contiguous 0..k-1 range.
 */
private static int[] countLabels(IntBuffer targets, int n, Iterable<Integer> labelValues) {
	int maxLabel = -1;
	for (int label : labelValues)
		maxLabel = Math.max(maxLabel, label);
	for (int i = 0; i < n; i++)
		maxLabel = Math.max(maxLabel, targets.get(i));
	var labelCounts = new int[maxLabel + 1];
	for (int i = 0; i < n; i++)
		labelCounts[targets.get(i)]++;
	return labelCounts;
}

/**
 * Creates an n x 1 {@code CV_32FC1} matrix of per-sample weights. Each sample's weight is
 * {@code n / count(label)}, so the total weight of each represented class is n. Labels with zero
 * count get a weight of 1.
 */
private static Mat createSampleWeights(IntBuffer targets, int n, int[] labelCounts) {
	var weightPerLabel = new float[labelCounts.length];
	for (int label = 0; label < labelCounts.length; label++) {
		int c = labelCounts[label];
		weightPerLabel[label] = c == 0 ? 1 : (float) n / c;
	}
	var weights = new Mat(n, 1, opencv_core.CV_32FC1);
	FloatIndexer indexer = weights.createIndexer();
	for (int i = 0; i < n; i++)
		indexer.put(i, weightPerLabel[targets.get(i)]);
	indexer.release();
	return weights;
}
Aggregations