Usage of ai.djl.examples.detection.domain.Landmark in the djl-demo project by deepjavalibrary.
The processOutput method of the FaceDetectionTranslator class.
/**
 * Decodes the raw model outputs into a set of detected faces.
 *
 * <p>The entries of {@code list} are read positionally:
 *
 * <ul>
 *   <li>index 0: box location offsets, decoded with {@code decodeBoxes};
 *   <li>index 1: confidences, decoded with {@code decodeConf};
 *   <li>index 2: landmark offsets, decoded with {@code decodeLandm}.
 * </ul>
 *
 * <p>Both boxes and landmarks are decoded against the priors produced by {@code boxRecover}.
 * Candidates whose score is strictly greater than {@code confThresh} are ranked by descending
 * score. At most {@code topK} of them are kept, and the survivors are filtered by
 * {@code nms} using {@code nmsThresh}.
 *
 * @param ctx the translator context (not used by this method)
 * @param list the model outputs: locations, confidences, landmarks, in that order
 * @return the detected faces, each labelled {@code "Face"} with its score, bounding box and
 *     landmark
 */
@Override
public FaceDetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    double[][] priors = this.boxRecover(width, height, scales, steps);

    NDArray loc = list.get(0);
    double[][] boxes = this.decodeBoxes(loc.toFloatArray(), priors, variance);

    NDArray conf = list.get(1);
    float[] scores = this.decodeConf(conf);

    NDArray landms = list.get(2);
    List<List<Point>> landmsList =
            this.decodeLandm(landms.toFloatArray(), priors, variance, width, height);

    // Highest score first, so that polling yields candidates in ranking order.
    PriorityQueue<FaceObject> pq =
            new PriorityQueue<FaceObject>(
                    10, (lhs, rhs) -> Double.compare(rhs.getScore(), lhs.getScore()));
    for (int i = 0; i < scores.length; i++) {
        if (scores[i] > this.confThresh) {
            // boxes[i] is treated as corner form {x1, y1, x2, y2};
            // Rectangle takes (x, y, width, height).
            BoundingBox rect =
                    new Rectangle(
                            boxes[i][0],
                            boxes[i][1],
                            boxes[i][2] - boxes[i][0],
                            boxes[i][3] - boxes[i][1]);
            pq.add(new FaceObject(scores[i], rect, landmsList.get(i)));
        }
    }

    // Keep at most topK best-scoring candidates before NMS. The bound is the list size itself,
    // so the cap is always enforced. A non-positive topK yields an empty list.
    ArrayList<FaceObject> topKArrayList = new ArrayList<>();
    while (!pq.isEmpty() && topKArrayList.size() < this.topK) {
        topKArrayList.add(pq.poll());
    }

    ArrayList<FaceObject> nmsList = this.nms(topKArrayList, this.nmsThresh);

    int count = nmsList.size();
    List<String> classNames = new ArrayList<>(count);
    List<Double> probabilities = new ArrayList<>(count);
    List<BoundingBox> boundingBoxes = new ArrayList<>(count);
    List<Landmark> landmarks = new ArrayList<>(count);
    for (FaceObject faceObject : nmsList) {
        classNames.add("Face");
        probabilities.add((double) faceObject.getScore());
        boundingBoxes.add(faceObject.getBoundingBox());
        landmarks.add(new Landmark(faceObject.getKeyPoints()));
    }
    return new FaceDetectedObjects(classNames, probabilities, boundingBoxes, landmarks);
}
Aggregations