Search in sources :

Example 21 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl-demo by deepjavalibrary.

the class FaceDetectionTranslator method processOutput.

/**
 * Converts the raw model output into a {@link FaceDetectedObjects} result.
 *
 * <p>Steps, in order: decode boxes, scores and landmarks against the priors; keep the
 * candidates whose score exceeds {@code confThresh}; retain at most {@code topK} of them
 * by descending score; run {@code nms} with {@code nmsThresh}; package the survivors.
 *
 * @param ctx the translator context (not used by this method)
 * @param list the model output; index 0 is read as box locations, index 1 as confidences
 *     and index 2 as landmarks
 * @return the detections that survive thresholding, top-K selection and NMS, each with
 *     the class name {@code "Face"}
 */
@Override
public FaceDetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    double[][] priors = this.boxRecover(width, height, scales, steps);
    NDArray loc = list.get(0);
    float[] locFloat = loc.toFloatArray();
    double[][] boxes = this.decodeBoxes(locFloat, priors, variance);
    NDArray conf = list.get(1);
    float[] scores = this.decodeConf(conf);
    NDArray landms = list.get(2);
    List<List<Point>> landmsList = this.decodeLandm(landms.toFloatArray(), priors, variance, width, height);

    // The comparator is reversed (rhs vs lhs) so that poll() yields the highest score first.
    PriorityQueue<FaceObject> pq = new PriorityQueue<FaceObject>(10, new Comparator<FaceObject>() {

        @Override
        public int compare(final FaceObject lhs, final FaceObject rhs) {
            return Double.compare(rhs.getScore(), lhs.getScore());
        }
    });
    for (int i = 0; i < scores.length; i++) {
        if (scores[i] > this.confThresh) {
            // NOTE(review): boxes[i] looks like {x1, y1, x2, y2}; width/height are derived
            // by subtraction here - confirm against decodeBoxes.
            BoundingBox rect = new Rectangle(boxes[i][0], boxes[i][1], boxes[i][2] - boxes[i][0], boxes[i][3] - boxes[i][1]);
            FaceObject faceObject = new FaceObject(scores[i], rect, landmsList.get(i));
            pq.add(faceObject);
        }
    }

    // Keep only the topK best candidates. The previous loop compared against a counter
    // that was never incremented, so the limit was never applied; bounding on the size of
    // the collected list enforces it.
    ArrayList<FaceObject> topKArrayList = new ArrayList<FaceObject>();
    while (!pq.isEmpty() && topKArrayList.size() < this.topK) {
        topKArrayList.add(pq.poll());
    }

    ArrayList<FaceObject> nmsList = this.nms(topKArrayList, this.nmsThresh);
    List<String> classNames = new ArrayList<String>();
    List<Double> probabilities = new ArrayList<Double>();
    List<BoundingBox> boundingBoxes = new ArrayList<BoundingBox>();
    List<Landmark> landmarks = new ArrayList<Landmark>();
    for (FaceObject faceObject : nmsList) {
        classNames.add("Face");
        probabilities.add((double) faceObject.getScore());
        boundingBoxes.add(faceObject.getBoundingBox());
        landmarks.add(new Landmark(faceObject.getKeyPoints()));
    }
    return new FaceDetectedObjects(classNames, probabilities, boundingBoxes, landmarks);
}
Also used : Rectangle(ai.djl.modality.cv.output.Rectangle) ArrayList(java.util.ArrayList) FaceObject(ai.djl.examples.detection.domain.FaceObject) BoundingBox(ai.djl.modality.cv.output.BoundingBox) NDList(ai.djl.ndarray.NDList) ArrayList(java.util.ArrayList) List(java.util.List) Landmark(ai.djl.examples.detection.domain.Landmark) FaceDetectedObjects(ai.djl.examples.detection.domain.FaceDetectedObjects) Point(ai.djl.modality.cv.output.Point) PriorityQueue(java.util.PriorityQueue) Point(ai.djl.modality.cv.output.Point) NDArray(ai.djl.ndarray.NDArray)

Aggregations

Rectangle (ai.djl.modality.cv.output.Rectangle)21 DetectedObjects (ai.djl.modality.cv.output.DetectedObjects)12 BoundingBox (ai.djl.modality.cv.output.BoundingBox)11 ArrayList (java.util.ArrayList)11 NDArray (ai.djl.ndarray.NDArray)7 Point (ai.djl.modality.cv.output.Point)5 List (java.util.List)5 Image (ai.djl.modality.cv.Image)4 NDList (ai.djl.ndarray.NDList)4 Path (java.nio.file.Path)4 FaceDetectedObjects (ai.djl.examples.detection.domain.FaceDetectedObjects)3 Artifact (ai.djl.repository.Artifact)3 PairList (ai.djl.util.PairList)3 Landmark (ai.djl.modality.cv.output.Landmark)2 ProgressBar (ai.djl.training.util.ProgressBar)2 SuppressLint (android.annotation.SuppressLint)2 Canvas (android.graphics.Canvas)2 Paint (android.graphics.Paint)2 Reader (java.io.Reader)2 Type (java.lang.reflect.Type)2