Use of sc.fiji.labkit.ui.labeling.Labeling in project labkit-ui by juglab.
Class TrainableSegmentationSegmenter, method train():
@Override
public void train(List<Pair<ImgPlus<?>, Labeling>> trainingData) {
    try {
        initFeatureSettings(trainingData);
        // Collect the class names used across all labelings.
        List<String> classes = collectLabels(trainingData.stream().map(Pair::getB)
            .collect(Collectors.toList()));
        sc.fiji.labkit.pixel_classification.classification.Segmenter segmenter =
            new sc.fiji.labkit.pixel_classification.classification.Segmenter(
                context, classes, featureSettings, new FastRandomForest());
        segmenter.setUseGpu(useGpu);
        // Feed every (image, labeling) pair into the training, then keep the trained segmenter.
        Training training = segmenter.training();
        for (Pair<ImgPlus<?>, Labeling> pair : trainingData)
            trainStack(training, classes, pair.getB(), pair.getA(), segmenter.features());
        training.train();
        this.segmenter = segmenter;
    }
    catch (RuntimeException e) {
        // Weka fails when no labeled pixels were provided; surface this as a cancellation
        // with a user-readable message.
        Throwable cause = e.getCause();
        if (cause instanceof WekaException &&
            cause.getMessage().contains("Not enough training instances"))
            throw new CancellationException("The training requires some labeled regions.");
        throw e;
    }
}
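The method expects one (image, labeling) pair per training image. Below is a minimal caller-side sketch of assembling that argument, assuming the Segmenter interface that the method above overrides; the class name TrainOnSingleImage, the helper method, and the exact import paths are illustrative assumptions, not Labkit API. The ValuePair/Collections.singletonList pattern mirrors the tests later in this section.

import java.util.Collections;
import java.util.List;
import net.imagej.ImgPlus;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import sc.fiji.labkit.ui.labeling.Labeling;
import sc.fiji.labkit.ui.segmentation.Segmenter;

class TrainOnSingleImage {

    // Wraps a single image and its labeling into the List<Pair<...>> shape that train expects.
    static void trainOnSingleImage(Segmenter segmenter, ImgPlus<?> image, Labeling labeling) {
        List<Pair<ImgPlus<?>, Labeling>> trainingData =
            Collections.singletonList(new ValuePair<>(image, labeling));
        segmenter.train(trainingData);
    }
}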
Class TrainableSegmentationSegmenter, method getClassIndices():
private SparseRandomAccessIntType getClassIndices(Labeling labeling, List<String> classes) {
    SparseRandomAccessIntType result = new SparseRandomAccessIntType(labeling, -1);
    Map<Set<Label>, Integer> classIndices = new HashMap<>();
    // A set of labels maps to the smallest matching class index, or to -1 if none of its
    // labels is listed in "classes".
    Function<Set<Label>, Integer> compute = set -> set.stream()
        .mapToInt(label -> classes.indexOf(label.name()))
        .filter(i -> i >= 0)
        .min()
        .orElse(-1);
    // Visit only the labeled pixels and write the class index of each into the sparse result.
    Cursor<?> cursor = labeling.sparsityCursor();
    RandomAccess<LabelingType<Label>> randomAccess = labeling.randomAccess();
    RandomAccess<IntType> out = result.randomAccess();
    while (cursor.hasNext()) {
        cursor.fwd();
        randomAccess.setPosition(cursor);
        Set<Label> labels = randomAccess.get();
        if (labels.isEmpty())
            continue;
        Integer classIndex = classIndices.computeIfAbsent(labels, compute);
        out.setPosition(cursor);
        out.get().set(classIndex);
    }
    return result;
}
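The mapping rule can be read off in isolation: a pixel's label set resolves to the smallest index of any of its labels that occurs in classes, and to -1 when none does. Below is a small self-contained sketch of just that rule in plain Java; the class and method names are made up for illustration, and it operates on plain strings rather than Labkit's Label objects.

import java.util.Arrays;
import java.util.List;
import java.util.Set;

class ClassIndexRuleDemo {

    // Same rule as the "compute" function above, applied to plain label names.
    static int classIndex(Set<String> pixelLabels, List<String> classes) {
        return pixelLabels.stream()
            .mapToInt(classes::indexOf) // -1 for names that are not in the class list
            .filter(i -> i >= 0)        // ignore unknown labels
            .min()                      // overlapping labels resolve to the smallest class index
            .orElse(-1);                // no matching label at all -> -1
    }

    public static void main(String[] args) {
        List<String> classes = Arrays.asList("background", "foreground");
        System.out.println(classIndex(Set.of("foreground"), classes));               // 1
        System.out.println(classIndex(Set.of("background", "foreground"), classes)); // 0
        System.out.println(classIndex(Set.of("somethingElse"), classes));            // -1
    }
}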
Class MultiChannelMovieDemo, method testInputImageImageForSegmentation():
@Test
public void testInputImageImageForSegmentation() {
    DatasetInputImage inputImage = inputImage5d();
    SegmentationModel segmentationModel = new DefaultSegmentationModel(new Context(), inputImage);
    SegmentationItem segmenter = segmentationModel.segmenterList().addSegmenter(
        PixelClassificationPlugin.create());
    Labeling labeling1 = labeling5d();
    segmentationModel.imageLabelingModel().labeling().set(labeling1);
    segmenter.train(Collections.singletonList(
        new ValuePair<>(inputImage.imageForSegmentation(), labeling1)));
    RandomAccessibleInterval<ShortType> result = segmenter.results(
        segmentationModel.imageLabelingModel()).segmentation();
    Labeling labeling = labeling5d();
    LoopBuilder.setImages(labeling, result).forEachPixel((l, r) -> {
        if (l.contains("foreground"))
            assertEquals(1, r.get());
        if (l.contains("background"))
            assertEquals(0, r.get());
    });
}
Class MultiChannelMovieDemo, method labeling5d():
private Labeling labeling5d() {
    Labeling labeling = Labeling.createEmpty(Arrays.asList("background", "foreground"),
        new FinalInterval(20, 10, 10, 20));
    RandomAccess<LabelingType<Label>> ra = labeling.randomAccess();
    ra.setPosition(new long[] { 1, 0, 0, 1 });
    ra.get().add(labeling.getLabel("foreground"));
    ra.setPosition(new long[] { 4, 0, 0, 1 });
    ra.get().add(labeling.getLabel("foreground"));
    ra.setPosition(new long[] { 5, 0, 0, 1 });
    ra.get().add(labeling.getLabel("background"));
    ra.setPosition(new long[] { 0, 0, 0, 1 });
    ra.get().add(labeling.getLabel("background"));
    return labeling;
}
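For completeness, reading labels back from such a sparse labeling uses the same random-access pattern. A short sketch, assuming it sits in the same class as labeling5d() above; LabelingType<Label> is a Set<Label> (as seen in getClassIndices() earlier), so it can be iterated directly.

Labeling labeling = labeling5d();
RandomAccess<LabelingType<Label>> ra = labeling.randomAccess();
ra.setPosition(new long[] { 1, 0, 0, 1 });
for (Label label : ra.get())
    System.out.println(label.name()); // prints "foreground" for the pixel labeled above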
Class SegmentationUseCaseTest, method testMultiChannel():
@Test
public void testMultiChannel() throws InterruptedException {
    Img<UnsignedByteType> img = ArrayImgs.unsignedBytes(
        new byte[] { -1, 0, -1, 0, -1, -1, 0, 0 }, 2, 2, 2);
    ImgPlus<UnsignedByteType> imgPlus = new ImgPlus<>(img, "Image",
        new AxisType[] { Axes.X, Axes.Y, Axes.CHANNEL });
    DatasetInputImage inputImage = new DatasetInputImage(imgPlus,
        BdvShowable.wrap(Views.hyperSlice(img, 2, 0)));
    Labeling labeling = getLabeling();
    SegmentationModel segmentationModel = new DefaultSegmentationModel(new Context(), inputImage);
    ImageLabelingModel imageLabelingModel = segmentationModel.imageLabelingModel();
    imageLabelingModel.labeling().set(labeling);
    SegmentationItem segmenter = segmentationModel.segmenterList().addSegmenter(
        PixelClassificationPlugin.create());
    segmenter.train(Collections.singletonList(
        new ValuePair<>(imgPlus, imageLabelingModel.labeling().get())));
    RandomAccessibleInterval<ShortType> result = segmenter.results(imageLabelingModel).segmentation();
    Iterator<ShortType> it = Views.iterable(result).iterator();
    assertEquals(1, it.next().get());
    assertEquals(0, it.next().get());
    assertEquals(0, it.next().get());
    assertEquals(0, it.next().get());
    assertTrue(Intervals.equals(new FinalInterval(2, 2), result));
}