Search in sources :

Example 1 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl-demo by deepjavalibrary.

the class MultiEngine method detectPersonWithPyTorchModel.

/**
 * Runs an object-detection model on the PyTorch engine and crops the first detected person
 * out of the input image.
 *
 * <p>The bounding box reported for a detection is scaled by the image size here, so it is
 * treated as being expressed as fractions of the image width/height. The resulting pixel
 * rectangle is clamped to the image so that a box that spills over an image edge does not
 * make the crop fail; a "person" detection with no visible area inside the image is skipped.
 *
 * @param img the image to search for a person
 * @return the cropped region of the first usable "person" detection, or {@code null} if
 *     there is none
 * @throws MalformedModelException declared for model loading
 * @throws ModelNotFoundException declared for model loading
 * @throws IOException declared for model loading
 * @throws TranslateException declared for inference
 */
private static Image detectPersonWithPyTorchModel(Image img) throws MalformedModelException, ModelNotFoundException, IOException, TranslateException {
    // Criteria object to load the model from the model zoo, pinned to the PyTorch engine
    Criteria<Image, DetectedObjects> criteria = Criteria.builder()
            .optApplication(Application.CV.OBJECT_DETECTION)
            .setTypes(Image.class, DetectedObjects.class)
            .optProgress(new ProgressBar())
            .optFilter("size", "300")
            .optFilter("backbone", "resnet50")
            .optFilter("dataset", "coco")
            .optEngine("PyTorch")
            .build();
    // Inference call to detect the person from the image; model and predictor are closed
    // as soon as the prediction is available.
    DetectedObjects detectedObjects;
    try (ZooModel<Image, DetectedObjects> ssd = criteria.loadModel();
        Predictor<Image, DetectedObjects> predictor = ssd.newPredictor()) {
        detectedObjects = predictor.predict(img);
    }
    int imageWidth = img.getWidth();
    int imageHeight = img.getHeight();
    // Get the first resulting image of the person and return it
    List<DetectedObjects.DetectedObject> items = detectedObjects.items();
    for (DetectedObjects.DetectedObject item : items) {
        if (!"person".equals(item.getClassName())) {
            continue;
        }
        Rectangle rect = item.getBoundingBox().getBounds();
        int x = (int) (rect.getX() * imageWidth);
        int y = (int) (rect.getY() * imageHeight);
        int w = (int) (rect.getWidth() * imageWidth);
        int h = (int) (rect.getHeight() * imageHeight);
        // Clamp to the image; for a box already inside the image this changes nothing.
        int left = Math.max(0, x);
        int top = Math.max(0, y);
        int right = Math.min(imageWidth, x + w);
        int bottom = Math.min(imageHeight, y + h);
        if (right <= left || bottom <= top) {
            // Nothing of this detection is visible inside the image; try the next one.
            continue;
        }
        return img.getSubimage(left, top, right - left, bottom - top);
    }
    return null;
}
Also used : Rectangle(ai.djl.modality.cv.output.Rectangle) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) Image(ai.djl.modality.cv.Image) ProgressBar(ai.djl.training.util.ProgressBar)

Example 2 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl-demo by deepjavalibrary.

the class FaceDetectionActivity method detect.

/**
 * Runs the face detector on an asset image, draws a green outline around every detection
 * on a mutable copy of the bitmap, and posts that copy to the given view.
 *
 * @param mContext context handed to {@code ImageUtil} to load the asset
 * @param assetPath path of the asset to load
 * @param imageView view that receives the annotated bitmap
 * @return a three-line summary: image size, number of detections, and elapsed time in ms
 *     (the timer covers loading the {@code Image} and the prediction call)
 * @throws IOException declared for asset loading
 * @throws TranslateException declared for inference
 */
private String detect(Context mContext, String assetPath, ImageView imageView) throws IOException, TranslateException {
    Bitmap source = ImageUtil.getBitmap(mContext, assetPath);
    StringBuilder report = new StringBuilder();
    report.append("Image size = ").append(source.getWidth()).append("x").append(source.getHeight()).append("\n");

    // Timed section: load the image for the predictor and run the prediction.
    long begin = System.currentTimeMillis();
    Image img = ImageUtil.getImage(mContext, assetPath);
    DetectedObjects detection = predictor.predict(img);
    report.append("Face detected = ").append(detection.getNumberOfObjects()).append("\n");
    report.append("Time: ").append(System.currentTimeMillis() - begin).append(" ms");

    // Draw on a mutable copy so the loaded bitmap itself is left untouched.
    Bitmap annotated = source.copy(Bitmap.Config.ARGB_8888, true);
    Canvas canvas = new Canvas(annotated);
    Paint outline = new Paint();
    outline.setColor(Color.GREEN);
    outline.setStyle(Paint.Style.STROKE);
    outline.setStrokeWidth(5);

    double imgWidth = img.getWidth();
    double imgHeight = img.getHeight();
    List<DetectedObjects.DetectedObject> faces = detection.items();
    for (DetectedObjects.DetectedObject face : faces) {
        Rectangle bounds = face.getBoundingBox().getBounds();
        // Bounds are multiplied by the image size to obtain pixel coordinates.
        int left = (int) (bounds.getX() * imgWidth);
        int top = (int) (bounds.getY() * imgHeight);
        int right = left + (int) (bounds.getWidth() * imgWidth);
        int bottom = top + (int) (bounds.getHeight() * imgHeight);
        canvas.drawRect(left, top, right, bottom, outline);
    }
    imageView.post(() -> imageView.setImageBitmap(annotated));
    return report.toString();
}
Also used : Canvas(android.graphics.Canvas) Rectangle(ai.djl.modality.cv.output.Rectangle) FaceDetectedObjects(ai.djl.examples.detection.domain.FaceDetectedObjects) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) Paint(android.graphics.Paint) Image(ai.djl.modality.cv.Image) SuppressLint(android.annotation.SuppressLint) Paint(android.graphics.Paint) Bitmap(android.graphics.Bitmap) BoundingBox(ai.djl.modality.cv.output.BoundingBox)

Example 3 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class SingleShotDetectionTranslator method processOutput.

/**
 * {@inheritDoc}
 *
 * <p>Reads class ids, probabilities and boxes from entries 0, 1 and 2 of the output list
 * and keeps every detection whose class id is non-negative and whose probability exceeds
 * {@code threshold}. Each box is read as two corner points (elements 0-1 and 2-3) and
 * converted to an x/y/width/height {@link Rectangle}; when {@code imageWidth} /
 * {@code imageHeight} are positive the coordinates are divided by them first.
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    float[] ids = list.get(0).toFloatArray();
    float[] scores = list.get(1).toFloatArray();
    NDArray boxes = list.get(2);

    List<String> names = new ArrayList<>();
    List<Double> probs = new ArrayList<>();
    List<BoundingBox> rects = new ArrayList<>();

    for (int idx = 0; idx < ids.length; ++idx) {
        int classId = (int) ids[idx];
        double probability = scores[idx];
        // classId starts from 0, -1 means background
        if (classId < 0 || !(probability > threshold)) {
            continue;
        }
        if (classId >= classes.size()) {
            throw new AssertionError("Unexpected index: " + classId);
        }
        String className = classes.get(classId);
        float[] corners = boundingBoxes(boxes, idx);

        // rescale box coordinates by imageWidth and imageHeight when they are set
        boolean scaleX = imageWidth > 0;
        boolean scaleY = imageHeight > 0;
        double x = scaleX ? corners[0] / imageWidth : corners[0];
        double y = scaleY ? corners[1] / imageHeight : corners[1];
        double w = (scaleX ? corners[2] / imageWidth : corners[2]) - x;
        double h = (scaleY ? corners[3] / imageHeight : corners[3]) - y;

        names.add(className);
        probs.add(probability);
        rects.add(new Rectangle(x, y, w, h));
    }
    return new DetectedObjects(names, probs, rects);
}

/** Returns the coordinates of the box stored at the given index as a float array. */
private static float[] boundingBoxes(NDArray boxes, int index) {
    return boxes.get(index).toFloatArray();
}
Also used : ArrayList(java.util.ArrayList) Rectangle(ai.djl.modality.cv.output.Rectangle) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) BoundingBox(ai.djl.modality.cv.output.BoundingBox) NDArray(ai.djl.ndarray.NDArray)

Example 4 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class YoloV5Translator method nms.

/**
 * Applies per-class non-maximum suppression to the candidate detections.
 *
 * <p>For each class index, the candidates of that class are processed from highest to
 * lowest confidence: the best remaining candidate is emitted, then every remaining
 * candidate whose {@code boxIou} with it is not below {@code nmsThreshold} is discarded.
 *
 * <p>The best candidate is taken with {@link PriorityQueue#poll()} rather than from
 * {@code toArray()[0]}, because the order of the array returned by
 * {@code PriorityQueue.toArray} is unspecified.
 *
 * @param list candidate detections across all classes
 * @return the surviving detections, grouped by class index in ascending order
 */
protected DetectedObjects nms(List<IntermediateResult> list) {
    List<String> retClasses = new ArrayList<>();
    List<Double> retProbs = new ArrayList<>();
    List<BoundingBox> retBB = new ArrayList<>();
    for (int k = 0; k < classes.size(); k++) {
        // 1. collect the candidates of class k, highest confidence at the head of the queue
        PriorityQueue<IntermediateResult> pq = new PriorityQueue<>(50,
                (lhs, rhs) -> Double.compare(rhs.getConfidence(), lhs.getConfidence()));
        for (IntermediateResult intermediateResult : list) {
            if (intermediateResult.getDetectedClass() == k) {
                pq.add(intermediateResult);
            }
        }
        // 2. do non maximum suppression
        while (!pq.isEmpty()) {
            // emit the detection with max confidence
            IntermediateResult best = pq.poll();
            Rectangle rec = best.getLocation();
            retClasses.add(best.id);
            retProbs.add(best.confidence);
            retBB.add(new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight()));
            // keep only candidates whose overlap with the emitted box is below the threshold;
            // written as a negated '<' so a NaN overlap is discarded, as before
            pq.removeIf(other -> !(boxIou(rec, other.getLocation()) < nmsThreshold));
        }
    }
    return new DetectedObjects(retClasses, retProbs, retBB);
}
Also used : ArrayList(java.util.ArrayList) Rectangle(ai.djl.modality.cv.output.Rectangle) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) PriorityQueue(java.util.PriorityQueue) BoundingBox(ai.djl.modality.cv.output.BoundingBox)

Example 5 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class YoloV5Translator method processFromBoxOutput.

/**
 * Converts the flat output of entry 0 of the list into candidate detections and passes
 * them through {@code nms}.
 *
 * <p>The array is read as consecutive rows of {@code 5 + classes.size()} floats: elements
 * 0-3 are used as x, y, width and height (x/y are shifted by half the width/height, i.e.
 * treated as the box center), element 4 multiplies the best class score (presumably an
 * objectness score - confirm against the model), and the remaining elements are per-class
 * scores. Rows whose combined score does not exceed {@code threshold} are dropped.
 *
 * @param list output list whose first entry holds the flattened rows
 * @return the detections remaining after non-maximum suppression
 */
private DetectedObjects processFromBoxOutput(NDList list) {
    float[] raw = list.get(0).toFloatArray();
    int classCount = classes.size();
    int rowLength = 5 + classCount;
    int rowCount = raw.length / rowLength;
    ArrayList<IntermediateResult> candidates = new ArrayList<>();

    for (int row = 0, offset = 0; row < rowCount; row++, offset += rowLength) {
        // Pick the class with the largest score; stays at class 0 / score 0 if none is positive.
        int bestClass = 0;
        float bestClassScore = 0;
        for (int c = 0; c < classCount; c++) {
            float classScore = raw[offset + c + 5];
            if (classScore > bestClassScore) {
                bestClassScore = classScore;
                bestClass = c;
            }
        }

        float score = bestClassScore * raw[offset + 4];
        if (!(score > threshold)) {
            continue;
        }

        float centerX = raw[offset];
        float centerY = raw[offset + 1];
        float boxWidth = raw[offset + 2];
        float boxHeight = raw[offset + 3];
        // Shift from center to top-left corner, never going below zero.
        float left = Math.max(0, centerX - boxWidth / 2);
        float top = Math.max(0, centerY - boxHeight / 2);
        Rectangle rect = new Rectangle(left, top, boxWidth, boxHeight);
        candidates.add(new IntermediateResult(classes.get(bestClass), score, bestClass, rect));
    }
    return nms(candidates);
}
Also used : ArrayList(java.util.ArrayList) Rectangle(ai.djl.modality.cv.output.Rectangle)

Aggregations

Rectangle (ai.djl.modality.cv.output.Rectangle): 21, DetectedObjects (ai.djl.modality.cv.output.DetectedObjects): 12, BoundingBox (ai.djl.modality.cv.output.BoundingBox): 11, ArrayList (java.util.ArrayList): 11, NDArray (ai.djl.ndarray.NDArray): 7, Point (ai.djl.modality.cv.output.Point): 5, List (java.util.List): 5, Image (ai.djl.modality.cv.Image): 4, NDList (ai.djl.ndarray.NDList): 4, Path (java.nio.file.Path): 4, FaceDetectedObjects (ai.djl.examples.detection.domain.FaceDetectedObjects): 3, Artifact (ai.djl.repository.Artifact): 3, PairList (ai.djl.util.PairList): 3, Landmark (ai.djl.modality.cv.output.Landmark): 2, ProgressBar (ai.djl.training.util.ProgressBar): 2, SuppressLint (android.annotation.SuppressLint): 2, Canvas (android.graphics.Canvas): 2, Paint (android.graphics.Paint): 2, Reader (java.io.Reader): 2, Type (java.lang.reflect.Type): 2