Use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
The class PikachuDetection, method prepare.
/**
 * Prepares the artifact (via {@code mrl.prepare}) and loads the Pikachu detection index for the
 * configured usage.
 *
 * <p>Reads {@code index.file} from the {@code train} or {@code test} directory of the resolved
 * artifact. Each JSON entry maps an image file name to a list of floats. Element 4 is read as the
 * object class id; elements 5 and 6 give the rectangle's point, and elements 7 and 8 its width and
 * height. Calling this method again after a successful run is a no-op.
 *
 * @param progress the progress tracker passed to {@code mrl.prepare}
 * @throws IOException if the artifact cannot be prepared, the index file cannot be read, or the
 *     index file is empty or contains a label with fewer than 9 values
 * @throws UnsupportedOperationException for any usage other than {@code TRAIN} or {@code TEST}
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);

    Path split;
    switch (usage) {
        case TRAIN:
            split = Paths.get("train");
            break;
        case TEST:
            split = Paths.get("test");
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Path usagePath = root.resolve(split);
    Path indexFile = usagePath.resolve("index.file");

    try (Reader reader = Files.newBufferedReader(indexFile)) {
        Type mapType = new TypeToken<Map<String, List<Float>>>() {}.getType();
        Map<String, List<Float>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
        // Gson yields null for an empty document; report it instead of failing with an NPE.
        if (metadata == null) {
            throw new IOException("Index file is empty: " + indexFile);
        }
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            String imgName = entry.getKey();
            List<Float> label = entry.getValue();
            // Indices 4..8 are read below, so at least 9 values are required.
            if (label == null || label.size() < 9) {
                throw new IOException(
                        "Malformed label for image " + imgName + " in " + indexFile);
            }
            long objectClass = label.get(4).longValue();
            Rectangle objectLocation =
                    new Rectangle(
                            new Point(label.get(5), label.get(6)), label.get(7), label.get(8));
            // Record the image only once its label parsed, so imagePaths and labels stay aligned.
            imagePaths.add(usagePath.resolve(imgName));
            labels.add(objectClass, objectLocation);
        }
    }
    prepared = true;
}
Use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
The class PtSsdTranslator, method processOutput.
/**
 * Decodes raw SSD outputs into detected objects.
 *
 * <p>{@code list.get(0)} holds box offsets and {@code list.get(1)} holds class scores. Both are
 * swapped on axes 0 and 1 before use.
 *
 * <p>Scores are soft-maxed over axis 1, and the first column (presumably background; TODO
 * confirm) is dropped.
 *
 * <p>Offsets are decoded against {@code boxRecover}:
 * <ul>
 *   <li>Size offsets are scaled by 0.2, exponentiated and multiplied by the prior sizes.</li>
 *   <li>Position offsets are scaled by 0.1, multiplied by the prior sizes, added to the prior
 *       positions and shifted by half the decoded size.</li>
 * </ul>
 *
 * <p>Candidates scoring below {@code threshold} are dropped. A greedy per-class suppression with
 * an IoU limit of 0.45 then keeps the highest-scoring boxes.
 *
 * @param ctx the translator context (not used here)
 * @param list the model outputs: box offsets at index 0 and class scores at index 1
 * @return the kept detections, ordered from highest to lowest probability
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    final double iouLimit = 0.45;
    final double xyVariance = 0.1;
    final double whVariance = 0.2;

    // Row 0: best class index (first column removed); row 1: that class's score.
    NDArray scores = list.get(1).swapAxes(0, 1).softmax(1).get(":, 1:");
    NDArray bestClass = scores.argMax(1).toType(DataType.FLOAT32, false);
    NDArray bestScore = scores.max(new int[] {1});
    NDArray classAndScore = NDArrays.stack(new NDList(bestClass, bestScore));

    NDArray offsets = list.get(0).swapAxes(0, 1);
    NDArray priorXY = boxRecover.get(":, :2");
    NDArray priorWH = boxRecover.get(":, 2:");
    NDArray sizes = offsets.get(":, 2:").mul(whVariance).exp().mul(priorWH);
    NDArray corners =
            offsets.get(":, :2").mul(xyVariance).mul(priorWH).add(priorXY).sub(sizes.mul(0.5f));
    NDArray decoded = NDArrays.concat(new NDList(corners, sizes), 1);

    // Keep only candidates whose score reaches the threshold.
    NDArray keep = classAndScore.get(1).gte(threshold);
    decoded = decoded.transpose().booleanMask(keep, 1).transpose();
    classAndScore = classAndScore.booleanMask(keep, 1);

    long[] ascending = classAndScore.get(1).argSort().toLongArray();
    NDArray rows = classAndScore.transpose();

    List<String> names = new ArrayList<>();
    List<Double> probabilities = new ArrayList<>();
    List<BoundingBox> kept = new ArrayList<>();
    Map<Integer, List<BoundingBox>> keptPerClass = new ConcurrentHashMap<>();
    // Walk from the highest score down so stronger boxes suppress weaker overlapping ones.
    for (int pos = ascending.length - 1; pos >= 0; --pos) {
        long idx = ascending[pos];
        float[] row = rows.get(idx).toFloatArray();
        int classId = (int) row[0];
        double[] coords = decoded.get(idx).toDoubleArray();
        Rectangle candidate = new Rectangle(coords[0], coords[1], coords[2], coords[3]);
        List<BoundingBox> sameClass =
                keptPerClass.computeIfAbsent(classId, k -> new ArrayList<>());
        boolean suppressed = sameClass.stream().anyMatch(b -> b.getIoU(candidate) > iouLimit);
        if (suppressed) {
            continue;
        }
        sameClass.add(candidate);
        names.add(classes.get(classId));
        probabilities.add((double) row[1]);
        kept.add(candidate);
    }
    return new DetectedObjects(names, probabilities, kept);
}
Use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
The class FaceDetectionTranslator, method processOutput.
/**
 * Decodes raw face-detection outputs into {@link Landmark} detections.
 *
 * <p>{@code list.get(0)} holds box offsets, {@code list.get(1)} class scores and
 * {@code list.get(2)} landmark offsets. The first score column is dropped (presumably
 * background; TODO confirm). The remaining columns give a best class and its score per
 * candidate.
 *
 * <p>Processing steps:
 * <ol>
 *   <li>Boxes are decoded against the priors from {@code boxRecover} using {@code variance};
 *       landmarks are decoded via {@code decodeLandm}.</li>
 *   <li>Candidates whose score is not above {@code confThresh} are discarded.</li>
 *   <li>At most {@code topK} of the highest-scoring candidates are kept.</li>
 *   <li>A greedy per-class suppression with {@code nmsThresh} selects the final faces.</li>
 * </ol>
 *
 * @param ctx the translator context, used for its {@link NDManager}
 * @param list the model outputs: box offsets, class scores and landmark offsets
 * @return the kept detections labelled {@code "Face"}, ordered from highest to lowest score
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    NDManager manager = ctx.getNDManager();
    double scaleXY = variance[0];
    double scaleWH = variance[1];

    // Row 0: best class index (first column removed); row 1: that class's score.
    NDArray prob = list.get(1).get(":, 1:");
    prob =
            NDArrays.stack(
                    new NDList(
                            prob.argMax(1).toType(DataType.FLOAT32, false),
                            prob.max(new int[] {1})));

    NDArray boxRecover = boxRecover(manager, width, height, scales, steps);
    NDArray boundingBoxes = list.get(0);
    NDArray bbWH = boundingBoxes.get(":, 2:").mul(scaleWH).exp().mul(boxRecover.get(":, 2:"));
    NDArray bbXY =
            boundingBoxes
                    .get(":, :2")
                    .mul(scaleXY)
                    .mul(boxRecover.get(":, 2:"))
                    .add(boxRecover.get(":, :2"))
                    .sub(bbWH.mul(0.5f));
    boundingBoxes = NDArrays.concat(new NDList(bbXY, bbWH), 1);
    NDArray landms = decodeLandm(list.get(2), boxRecover, scaleXY);

    // Filter out candidates whose score is not above the confidence threshold.
    NDArray cutOff = prob.get(1).gt(confThresh);
    boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
    landms = landms.transpose().booleanMask(cutOff, 1).transpose();
    prob = prob.booleanMask(cutOff, 1);

    // Sort scores in DESCENDING order before truncating, so the topK candidates kept are the
    // highest-scoring ones. Sorting ascending and slicing ":topK" kept the lowest scores instead.
    long[] order = prob.get(1).argSort(-1, false).get(":" + topK).toLongArray();
    prob = prob.transpose();
    List<String> retNames = new ArrayList<>();
    List<Double> retProbs = new ArrayList<>();
    List<BoundingBox> retBB = new ArrayList<>();
    Map<Integer, List<BoundingBox>> recorder = new ConcurrentHashMap<>();
    // Visit candidates best-first so stronger boxes suppress weaker overlapping ones.
    for (long currMaxLoc : order) {
        float[] classProb = prob.get(currMaxLoc).toFloatArray();
        int classId = (int) classProb[0];
        double probability = classProb[1];
        double[] boxArr = boundingBoxes.get(currMaxLoc).toDoubleArray();
        Rectangle rect = new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);
        List<BoundingBox> boxes = recorder.computeIfAbsent(classId, k -> new ArrayList<>());
        boolean belowIoU = true;
        for (BoundingBox box : boxes) {
            if (box.getIoU(rect) > nmsThresh) {
                belowIoU = false;
                break;
            }
        }
        if (!belowIoU) {
            continue;
        }
        double[] landmsArr = landms.get(currMaxLoc).toDoubleArray();
        List<Point> keyPoints = new ArrayList<>();
        // 5 face landmarks, stored as interleaved (x, y) pairs.
        for (int j = 0; j < 5; j++) {
            double x = landmsArr[j * 2];
            double y = landmsArr[j * 2 + 1];
            // NOTE(review): landmark points are scaled by width/height here, while the box
            // coordinates below are passed unscaled; presumably pixel vs. normalized space.
            // Confirm against Landmark consumers.
            keyPoints.add(new Point(x * width, y * height));
        }
        Landmark landmark = new Landmark(boxArr[0], boxArr[1], boxArr[2], boxArr[3], keyPoints);
        boxes.add(landmark);
        // Every detection from this model is reported as a face.
        String className = "Face";
        retNames.add(className);
        retProbs.add(probability);
        retBB.add(landmark);
    }
    return new DetectedObjects(retNames, retProbs, retBB);
}
Use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
The class MaskDetectionTest, method getSubImage.
/**
 * Crops a square region around a detected face, enlarged with a factor of 0.18.
 *
 * <p>The box bounds are scaled by the image size (so they are presumably normalized; confirm)
 * and passed to {@code extendSquare}. Elements 0 and 1 of its result are used as the crop's
 * top-left corner, and element 2 as the side length.
 *
 * <p>The crop is clipped to the image so a face near the border does not request pixels outside
 * it. For squares fully inside the image the crop is unchanged.
 *
 * @param img the full image
 * @param box the face detection to crop around
 * @return the cropped sub-image (at least 1x1 pixel)
 */
private static Image getSubImage(Image img, BoundingBox box) {
    Rectangle rect = box.getBounds();
    int width = img.getWidth();
    int height = img.getHeight();
    int[] squareBox =
            extendSquare(
                    rect.getX() * width,
                    rect.getY() * height,
                    rect.getWidth() * width,
                    rect.getHeight() * height,
                    0.18);
    int side = squareBox[2];
    // Clip the enlarged square to the image bounds. An image backend may reject a sub-image
    // that extends past the border (e.g. BufferedImage.getSubimage throws).
    int x = Math.min(Math.max(squareBox[0], 0), width - 1);
    int y = Math.min(Math.max(squareBox[1], 0), height - 1);
    int cropWidth = Math.max(1, Math.min(squareBox[0] + side, width) - x);
    int cropHeight = Math.max(1, Math.min(squareBox[1] + side, height) - y);
    return img.getSubImage(x, y, cropWidth, cropHeight);
}
Use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.
The class OCRTest, method getSubImage.
/**
 * Crops the region of a detected text box after enlarging it with {@code extendRect}.
 *
 * <p>{@code extendRect} receives the box's bounds as-is. Its four results are treated as x, y,
 * width and height, scaled by the image width/height, and truncated to ints for the crop (so
 * they are presumably normalized coordinates; confirm).
 *
 * @param img the full image
 * @param box the detected text box
 * @return the cropped sub-image
 */
private static Image getSubImage(Image img, BoundingBox box) {
    Rectangle bounds = box.getBounds();
    double[] grown =
            extendRect(bounds.getX(), bounds.getY(), bounds.getWidth(), bounds.getHeight());
    int imgW = img.getWidth();
    int imgH = img.getHeight();
    int left = (int) (grown[0] * imgW);
    int top = (int) (grown[1] * imgH);
    int cropW = (int) (grown[2] * imgW);
    int cropH = (int) (grown[3] * imgH);
    return img.getSubImage(left, top, cropW, cropH);
}
Aggregations