Usage of ai.djl.modality.cv.output.Rectangle in the djl project by deepjavalibrary.
Class PpFaceDetectionTranslator, method processOutput.
/**
 * {@inheritDoc}
 *
 * <p>Each row of the single output array is read as {@code [classId, score, x1, y1, x2, y2]}.
 * Rows whose score (column 1) is at least {@code threshold} become detections; their corner
 * coordinates are converted into an origin-plus-size {@link Rectangle}.
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    NDArray detections = list.singletonOrThrow();
    float[] scores = detections.get(":,1").toFloatArray();

    List<String> labels = new ArrayList<>();
    List<Double> confidences = new ArrayList<>();
    List<BoundingBox> regions = new ArrayList<>();

    for (int row = 0; row < scores.length; row++) {
        float score = scores[row];
        // Negated form keeps NaN scores rejected, exactly like a plain ">= threshold" test.
        if (!(score >= threshold)) {
            continue;
        }
        float[] values = detections.get(row).toFloatArray();
        float left = values[2];
        float top = values[3];
        float right = values[4];
        float bottom = values[5];

        labels.add(className.get((int) values[0]));
        confidences.add((double) score);
        regions.add(new Rectangle(left, top, right - left, bottom - top));
    }
    return new DetectedObjects(labels, confidences, regions);
}
Usage of ai.djl.modality.cv.output.Rectangle in the djl project by deepjavalibrary.
Class PpWordDetectionTranslator, method processOutput.
/**
 * {@inheritDoc}
 *
 * <p>Turns the model's single output map into word bounding boxes.
 *
 * <ul>
 *   <li>When the active {@link ImageFactory} is an {@code OpenCVImageFactory}, the map is
 *       converted to an image and its boxes come from {@code findBoundingBoxes()}. A box is kept
 *       only if its width or height, multiplied by the image size, exceeds 5 (presumably
 *       pixels).
 *   <li>Otherwise the map is turned into a boolean grid of nonzero cells, and {@code BoundFinder}
 *       extracts the boxes from it.
 * </ul>
 *
 * <p>Every resulting box is labelled {@code "word"} with probability {@code 1.0}.
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws FileNotFoundException {
    NDArray output = list.singletonOrThrow();
    ImageFactory imageFactory = ImageFactory.getInstance();
    List<BoundingBox> boxes;

    if (!(imageFactory instanceof OpenCVImageFactory)) {
        // Portable path: scale to 0..255, cast to bytes, and mark every nonzero cell.
        NDArray mask = output.squeeze().mul(255f).toType(DataType.UINT8, true).neq(0);
        boolean[] flat = mask.toBooleanArray();
        Shape maskShape = mask.getShape();
        int rows = (int) maskShape.get(0);
        int cols = (int) maskShape.get(1);
        boolean[][] grid = new boolean[rows][cols];
        // Row-major unflattening; each index writes a distinct cell, so parallel is safe.
        IntStream.range(0, flat.length)
                .parallel()
                .forEach(idx -> grid[idx / cols][idx % cols] = flat[idx]);
        boxes = new BoundFinder(grid).getBoxes();
    } else {
        // Faster OpenCV-backed path.
        // NOTE(review): unlike the portable path there is no *255 scaling before the UINT8 cast;
        // if the map holds [0, 1] probabilities, values below 1 truncate to 0 here -- confirm the
        // model output is already binarized.
        Image image = imageFactory.fromNDArray(output.squeeze(0).toType(DataType.UINT8, true));
        boxes = image.findBoundingBoxes()
                .parallelStream()
                .filter(candidate -> {
                    Rectangle region = (Rectangle) candidate;
                    boolean wideEnough = region.getWidth() * image.getWidth() > 5;
                    return wideEnough || region.getHeight() * image.getHeight() > 5;
                })
                .collect(Collectors.toList());
    }

    int count = boxes.size();
    List<String> names = new ArrayList<>(count);
    List<Double> probs = new ArrayList<>(count);
    while (names.size() < count) {
        names.add("word");
        probs.add(1.0);
    }
    return new DetectedObjects(names, probs, boxes);
}
Usage of ai.djl.modality.cv.output.Rectangle in the djl project by deepjavalibrary.
Class YoloTranslator, method processOutput.
/**
 * {@inheritDoc}
 *
 * <p>Reads three arrays from {@code list}:
 *
 * <ul>
 *   <li>index 0: class indices;
 *   <li>index 1: scores;
 *   <li>index 2: boxes, with columns {@code xMin, yMin, xMax, yMax}.
 * </ul>
 *
 * <p>Box corners are clipped to {@code [0, imageWidth]} and {@code [0, imageHeight]}, then
 * divided by those sizes, so the returned rectangles are relative to the image. Detections with
 * a negative class index or a score below {@code threshold} are dropped.
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    int[] classIndices = list.get(0).toType(DataType.INT32, true).flatten().toIntArray();
    double[] probs = list.get(1).toType(DataType.FLOAT64, true).flatten().toDoubleArray();
    NDArray boundingBoxes = list.get(2);
    int detected = probs.length;

    // Clip each corner to the image bounds and normalize it to [0, 1].
    NDArray xMin = boundingBoxes.get(":, 0").clip(0, imageWidth).div(imageWidth);
    NDArray yMin = boundingBoxes.get(":, 1").clip(0, imageHeight).div(imageHeight);
    NDArray xMax = boundingBoxes.get(":, 2").clip(0, imageWidth).div(imageWidth);
    NDArray yMax = boundingBoxes.get(":, 3").clip(0, imageHeight).div(imageHeight);

    float[] boxX = xMin.toFloatArray();
    float[] boxY = yMin.toFloatArray();
    float[] boxWidth = xMax.sub(xMin).toFloatArray();
    float[] boxHeight = yMax.sub(yMin).toFloatArray();

    List<String> retClasses = new ArrayList<>(detected);
    List<Double> retProbs = new ArrayList<>(detected);
    List<BoundingBox> retBB = new ArrayList<>(detected);
    for (int i = 0; i < detected; i++) {
        if (classIndices[i] < 0 || probs[i] < threshold) {
            continue;
        }
        retClasses.add(classes.get(classIndices[i]));
        retProbs.add(probs[i]);
        // The coordinates are already relative after the division above. Dividing by the image
        // size again when applyRatio is set would shrink every box by that factor, so the same
        // rectangle is produced regardless of applyRatio.
        retBB.add(new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]));
    }
    return new DetectedObjects(retClasses, retProbs, retBB);
}
Usage of ai.djl.modality.cv.output.Rectangle in the djl project by deepjavalibrary.
Class BananaDetection, method prepare.
/**
 * {@inheritDoc}
 *
 * <p>Prepares the default artifact of {@code mrl}, then reads {@code index.file} from the
 * {@code train} or {@code test} sub-directory. Each JSON entry maps an image file name to a
 * label list. Element 0 is the class; elements 1-4 are passed as the origin point and size of
 * a {@link Rectangle} (presumably x, y, width, height -- confirm against the dataset).
 *
 * @throws IOException if the index file cannot be read, is empty, or holds a malformed label
 * @throws UnsupportedOperationException if {@code usage} is not TRAIN or TEST
 */
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("train");
            break;
        case TEST:
            usagePath = Paths.get("test");
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    usagePath = root.resolve(usagePath);
    Path indexFile = usagePath.resolve("index.file");
    try (Reader reader = Files.newBufferedReader(indexFile)) {
        Type mapType = new TypeToken<Map<String, List<Float>>>() {}.getType();
        Map<String, List<Float>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
        if (metadata == null) {
            // Gson returns null when the input is empty; report it instead of failing with an NPE.
            throw new IOException("Empty banana detection index file: " + indexFile);
        }
        // Validate every entry before touching imagePaths/labels, so a malformed file cannot
        // leave the two collections partially filled or out of step with each other.
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            List<Float> label = entry.getValue();
            if (label == null || label.size() < 5) {
                throw new IOException(
                        "Malformed label for image " + entry.getKey() + " in " + indexFile
                                + ": " + label);
            }
        }
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            String imgName = entry.getKey();
            List<Float> label = entry.getValue();
            long objectClass = label.get(0).longValue();
            Rectangle objectLocation =
                    new Rectangle(new Point(label.get(1), label.get(2)), label.get(3), label.get(4));
            imagePaths.add(usagePath.resolve(imgName));
            labels.add(objectClass, objectLocation);
        }
    }
    prepared = true;
}
Usage of ai.djl.modality.cv.output.Rectangle in the djl project by deepjavalibrary.
Class CocoDetection, method getLabels.
/**
 * Collects the object labels for one COCO image.
 *
 * <p>Annotations whose area is not positive are skipped. Every other annotation contributes one
 * pair: its mapped category id, and a rectangle built from its bounding box, read as
 * {@code [x, y, width, height]}.
 *
 * @param coco the COCO metadata helper to query
 * @param imageId the id of the image whose annotations are wanted
 * @return the label pairs; empty when the image has no annotation list
 */
private PairList<Long, Rectangle> getLabels(CocoUtils coco, long imageId) {
    List<Long> ids = coco.getAnnotationIdByImageId(imageId);
    if (ids == null) {
        return new PairList<>();
    }
    PairList<Long, Rectangle> result = new PairList<>(ids.size());
    for (long id : ids) {
        CocoMetadata.Annotation ann = coco.getAnnotationById(id);
        // Negated form also skips a NaN area, matching a plain "> 0" test.
        if (!(ann.getArea() > 0)) {
            continue;
        }
        double[] bbox = ann.getBoundingBox();
        long category = coco.mapCategoryId(ann.getCategoryId());
        Point origin = new Point(bbox[0], bbox[1]);
        result.add(category, new Rectangle(origin, bbox[2], bbox[3]));
    }
    return result;
}
Aggregations