use of ai.djl.modality.cv.output.Rectangle in project djl-demo by deepjavalibrary.
the class MultiEngine method detectPersonWithPyTorchModel.
/**
 * Runs an object-detection model on the PyTorch engine and crops the first detected person.
 *
 * <p>Bounding-box coordinates are scaled by the image width/height to obtain pixels, and the
 * resulting rectangle is clamped to the image so that a box reaching past an edge still yields
 * a valid crop.
 *
 * @param img the image to search for a person
 * @return the cropped region of the first "person" detection whose clamped box is non-empty,
 *     or {@code null} when there is no such detection
 * @throws MalformedModelException if the model cannot be loaded because it is malformed
 * @throws ModelNotFoundException if no model matches the criteria
 * @throws IOException if loading the model fails with an I/O error
 * @throws TranslateException if inference fails
 */
private static Image detectPersonWithPyTorchModel(Image img) throws MalformedModelException, ModelNotFoundException, IOException, TranslateException {
    // Criteria object to load the model from the model zoo, pinned to the PyTorch engine.
    Criteria<Image, DetectedObjects> criteria =
            Criteria.builder()
                    .optApplication(Application.CV.OBJECT_DETECTION)
                    .setTypes(Image.class, DetectedObjects.class)
                    .optProgress(new ProgressBar())
                    .optFilter("size", "300")
                    .optFilter("backbone", "resnet50")
                    .optFilter("dataset", "coco")
                    .optEngine("PyTorch")
                    .build();

    // Inference call to detect the person from the image; model and predictor are closed afterwards.
    DetectedObjects detectedObjects;
    try (ZooModel<Image, DetectedObjects> ssd = criteria.loadModel();
            Predictor<Image, DetectedObjects> predictor = ssd.newPredictor()) {
        detectedObjects = predictor.predict(img);
    }

    int width = img.getWidth();
    int height = img.getHeight();

    // Return the crop of the first usable "person" detection.
    List<DetectedObjects.DetectedObject> items = detectedObjects.items();
    for (DetectedObjects.DetectedObject item : items) {
        if (!"person".equals(item.getClassName())) {
            continue;
        }
        Rectangle rect = item.getBoundingBox().getBounds();
        int x = (int) (rect.getX() * width);
        int y = (int) (rect.getY() * height);
        int w = (int) (rect.getWidth() * width);
        int h = (int) (rect.getHeight() * height);

        // Clamp to the image: a box that sticks out past an edge would otherwise produce a
        // negative origin or an oversized extent, which getSubimage may reject.
        int left = Math.max(0, Math.min(x, width));
        int top = Math.max(0, Math.min(y, height));
        int right = Math.max(left, Math.min(x + w, width));
        int bottom = Math.max(top, Math.min(y + h, height));

        // An empty intersection cannot be cropped; try the next detection instead.
        if (right > left && bottom > top) {
            return img.getSubimage(left, top, right - left, bottom - top);
        }
    }
    return null;
}
use of ai.djl.modality.cv.output.Rectangle in project djl-demo by deepjavalibrary.
the class FaceDetectionActivity method detect.
/**
 * Runs face detection on an asset image, draws a green outline around every detection on a
 * mutable copy of the bitmap, and posts that copy to the given view.
 *
 * @param mContext context used to load the asset
 * @param assetPath path of the image asset
 * @param imageView view that receives the annotated bitmap
 * @return a summary with the image size, the number of detections, and the elapsed time
 * @throws IOException if the asset cannot be read
 * @throws TranslateException if inference fails
 */
private String detect(Context mContext, String assetPath, ImageView imageView) throws IOException, TranslateException {
    Bitmap source = ImageUtil.getBitmap(mContext, assetPath);
    StringBuilder report = new StringBuilder();
    report.append("Image size = ")
            .append(source.getWidth())
            .append("x")
            .append(source.getHeight())
            .append("\n");

    // The timed section covers both loading the image for inference and the prediction itself.
    long begin = System.currentTimeMillis();
    Image img = ImageUtil.getImage(mContext, assetPath);
    DetectedObjects detection = predictor.predict(img);
    report.append("Face detected = ").append(detection.getNumberOfObjects()).append("\n");
    report.append("Time: ").append(System.currentTimeMillis() - begin).append(" ms");

    // Draw on a mutable ARGB copy so the loaded bitmap itself is not modified.
    Bitmap drawBitmap = source.copy(Bitmap.Config.ARGB_8888, true);
    Canvas canvas = new Canvas(drawBitmap);
    Paint outline = new Paint();
    outline.setColor(Color.GREEN);
    outline.setStyle(Paint.Style.STROKE);
    outline.setStrokeWidth(5);

    // Box coordinates are multiplied by the image dimensions to obtain pixel positions.
    double imgWidth = img.getWidth();
    double imgHeight = img.getHeight();
    for (DetectedObjects.DetectedObject face : detection.items()) {
        Rectangle bounds = face.getBoundingBox().getBounds();
        int left = (int) (bounds.getX() * imgWidth);
        int top = (int) (bounds.getY() * imgHeight);
        int right = left + (int) (bounds.getWidth() * imgWidth);
        int bottom = top + (int) (bounds.getHeight() * imgHeight);
        canvas.drawRect(left, top, right, bottom, outline);
    }

    imageView.post(() -> imageView.setImageBitmap(drawBitmap));
    return report.toString();
}
use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
the class SingleShotDetectionTranslator method processOutput.
/**
 * {@inheritDoc}
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    // Entries are read positionally: 0 = class ids, 1 = probabilities, 2 = bounding boxes.
    float[] ids = list.get(0).toFloatArray();
    float[] scores = list.get(1).toFloatArray();
    NDArray boxes = list.get(2);

    List<String> names = new ArrayList<>();
    List<Double> probs = new ArrayList<>();
    List<BoundingBox> bounds = new ArrayList<>();

    for (int idx = 0; idx < ids.length; ++idx) {
        int classId = (int) ids[idx];
        double probability = scores[idx];

        // Class ids start at 0; -1 marks background. Low-scoring candidates are dropped too.
        if (classId < 0 || !(probability > threshold)) {
            continue;
        }
        if (classId >= classes.size()) {
            throw new AssertionError("Unexpected index: " + classId);
        }
        String className = classes.get(classId);

        // box[0..1] is used as the top-left corner and box[2..3] as the opposite corner;
        // each axis is divided by the configured image size when that size is positive.
        float[] box = boxes.get(idx).toFloatArray();
        double x = imageWidth > 0 ? box[0] / imageWidth : box[0];
        double y = imageHeight > 0 ? box[1] / imageHeight : box[1];
        double right = imageWidth > 0 ? box[2] / imageWidth : box[2];
        double bottom = imageHeight > 0 ? box[3] / imageHeight : box[3];

        names.add(className);
        probs.add(probability);
        bounds.add(new Rectangle(x, y, right - x, bottom - y));
    }
    return new DetectedObjects(names, probs, bounds);
}
use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
the class YoloV5Translator method nms.
/**
 * Applies greedy per-class non-maximum suppression to the candidate detections.
 *
 * <p>For every class index, the candidate with the highest confidence is emitted, every
 * remaining candidate whose {@code boxIou} with it is not below {@code nmsThreshold} is
 * discarded, and the process repeats until no candidate of that class is left.
 *
 * @param list the candidate detections across all classes
 * @return the surviving detections (class name, confidence, bounding box)
 */
protected DetectedObjects nms(List<IntermediateResult> list) {
    List<String> retClasses = new ArrayList<>();
    List<Double> retProbs = new ArrayList<>();
    List<BoundingBox> retBB = new ArrayList<>();

    for (int k = 0; k < classes.size(); k++) {
        // 1. Gather the candidates of class k. The comparator is intentionally reversed so the
        //    highest confidence sits at the head of the queue.
        PriorityQueue<IntermediateResult> pq =
                new PriorityQueue<>(
                        50,
                        (lhs, rhs) -> Double.compare(rhs.getConfidence(), lhs.getConfidence()));
        for (IntermediateResult intermediateResult : list) {
            if (intermediateResult.getDetectedClass() == k) {
                pq.add(intermediateResult);
            }
        }

        // 2. Do non-maximum suppression.
        while (!pq.isEmpty()) {
            // poll() is the only operation guaranteed to return the head; toArray() makes no
            // ordering promise, so its first element must not be treated as the maximum.
            IntermediateResult best = pq.poll();
            Rectangle rec = best.getLocation();
            retClasses.add(best.id);
            retProbs.add(best.confidence);
            retBB.add(new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight()));

            // Keep only candidates whose overlap with the chosen box is below the threshold.
            // The negated "<" mirrors the keep-condition exactly (including for NaN).
            pq.removeIf(candidate -> !(boxIou(rec, candidate.getLocation()) < nmsThreshold));
        }
    }
    return new DetectedObjects(retClasses, retProbs, retBB);
}
use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
the class YoloV5Translator method processFromBoxOutput.
/**
 * Converts the flattened box output into candidate detections and runs {@code nms} on them.
 *
 * <p>The first array in {@code list} is read as consecutive rows of {@code 5 + classes.size()}
 * floats: centre x, centre y, width, height, an objectness value, then one score per class.
 * A row becomes a candidate when its best class score times the objectness value exceeds
 * {@code threshold}.
 *
 * @param list the network output; only element 0 is read
 * @return the detections that survive non-maximum suppression
 */
private DetectedObjects processFromBoxOutput(NDList list) {
    float[] raw = list.get(0).toFloatArray();
    int classCount = classes.size();
    int rowLength = 5 + classCount;
    int rows = raw.length / rowLength;

    List<IntermediateResult> candidates = new ArrayList<>();
    for (int row = 0; row < rows; row++) {
        int base = row * rowLength;
        int scoresStart = base + 5;

        // Find the best-scoring class; scores that are not above 0 never win, leaving class 0.
        int bestClass = 0;
        float bestScore = 0;
        for (int c = 0; c < classCount; c++) {
            float classScore = raw[scoresStart + c];
            if (classScore > bestScore) {
                bestScore = classScore;
                bestClass = c;
            }
        }

        float score = bestScore * raw[base + 4];
        if (!(score > threshold)) {
            continue;
        }

        // Convert centre/size to a top-left corner, clamping the corner at 0.
        float w = raw[base + 2];
        float h = raw[base + 3];
        float left = Math.max(0, raw[base] - w / 2);
        float top = Math.max(0, raw[base + 1] - h / 2);
        Rectangle rect = new Rectangle(left, top, w, h);
        candidates.add(new IntermediateResult(classes.get(bestClass), score, bestClass, rect));
    }
    return nms(candidates);
}
Aggregations