Search in sources :

Example 6 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class PikachuDetection method prepare.

/**
 * {@inheritDoc}
 *
 * <p>Resolves the default artifact, then reads the {@code index.file} of the selected
 * split ({@code train} or {@code test}) and records one image path plus one labelled
 * rectangle per index entry. Calling this method again after a successful run is a no-op.
 *
 * @throws IOException if the artifact cannot be prepared or the index file cannot be read
 * @throws UnsupportedOperationException if the configured usage is neither TRAIN nor TEST
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        // Already loaded; nothing to do.
        return;
    }

    Artifact defaultArtifact = mrl.getDefaultArtifact();
    mrl.prepare(defaultArtifact, progress);
    Path resourceDir = mrl.getRepository().getResourceDirectory(defaultArtifact);

    // Pick the sub-directory that matches the requested split.
    Path splitDir;
    switch (usage) {
        case TRAIN:
            splitDir = resourceDir.resolve(Paths.get("train"));
            break;
        case TEST:
            splitDir = resourceDir.resolve(Paths.get("test"));
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }

    // The index is a JSON object: image file name -> list of float values.
    Type indexType = new TypeToken<Map<String, List<Float>>>() {
    }.getType();
    try (Reader indexReader = Files.newBufferedReader(splitDir.resolve("index.file"))) {
        Map<String, List<Float>> index = JsonUtils.GSON.fromJson(indexReader, indexType);
        for (Map.Entry<String, List<Float>> item : index.entrySet()) {
            imagePaths.add(splitDir.resolve(item.getKey()));

            // Element 4 is used as the class id; elements 5..8 feed the rectangle
            // (a point followed by two extents - presumably x, y, width, height; TODO confirm).
            List<Float> values = item.getValue();
            long classId = values.get(4).longValue();
            Point point = new Point(values.get(5), values.get(6));
            Rectangle region = new Rectangle(point, values.get(7), values.get(8));
            labels.add(classId, region);
        }
    }
    prepared = true;
}
Also used : Path(java.nio.file.Path) Rectangle(ai.djl.modality.cv.output.Rectangle) Reader(java.io.Reader) Point(ai.djl.modality.cv.output.Point) Artifact(ai.djl.repository.Artifact) Type(java.lang.reflect.Type) ArrayList(java.util.ArrayList) PairList(ai.djl.util.PairList) List(java.util.List) Map(java.util.Map)

Example 7 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class PtSsdTranslator method processOutput.

/**
 * {@inheritDoc}
 *
 * <p>Decodes the raw box regressions against {@code boxRecover}, discards candidates whose
 * best class score is below {@code threshold}, and then greedily keeps, per class, only
 * boxes whose IoU with every previously kept box of that class does not exceed 0.45.
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    final double xyScale = 0.1;
    final double whScale = 0.2;
    final double iouLimit = 0.45;

    // Softmax the class scores and drop column 0 before selecting the best class per box.
    NDArray scores = list.get(1).swapAxes(0, 1).softmax(1).get(":, 1:");
    NDArray bestClass = scores.argMax(1).toType(DataType.FLOAT32, false);
    NDArray bestScore = scores.max(new int[] { 1 });
    // Row 0 holds the class index, row 1 the matching score.
    NDArray prob = NDArrays.stack(new NDList(bestClass, bestScore));

    // Decode sizes and positions relative to the recover boxes, then shift each
    // position by half the size.
    NDArray rawBoxes = list.get(0).swapAxes(0, 1);
    NDArray recoverXY = boxRecover.get(":, :2");
    NDArray recoverWH = boxRecover.get(":, 2:");
    NDArray sizes = rawBoxes.get(":, 2:").mul(whScale).exp().mul(recoverWH);
    NDArray centers = rawBoxes.get(":, :2").mul(xyScale).mul(recoverWH).add(recoverXY);
    NDArray corners = centers.sub(sizes.mul(0.5f));
    NDArray boundingBoxes = NDArrays.concat(new NDList(corners, sizes), 1);

    // Keep only the candidates whose score reaches the threshold.
    NDArray keep = prob.get(1).gte(threshold);
    boundingBoxes = boundingBoxes.transpose().booleanMask(keep, 1).transpose();
    prob = prob.booleanMask(keep, 1);

    // Candidate indices ordered by score; the loop below walks them from the end
    // (presumably highest score first, assuming argSort() is ascending).
    long[] ranked = prob.get(1).argSort().toLongArray();
    prob = prob.transpose();

    List<String> names = new ArrayList<>();
    List<Double> probabilities = new ArrayList<>();
    List<BoundingBox> keptBoxes = new ArrayList<>();
    Map<Integer, List<BoundingBox>> keptByClass = new ConcurrentHashMap<>();
    for (int pos = ranked.length - 1; pos >= 0; pos--) {
        long candidate = ranked[pos];
        float[] classAndScore = prob.get(candidate).toFloatArray();
        int classId = (int) classAndScore[0];
        double probability = classAndScore[1];
        double[] coords = boundingBoxes.get(candidate).toDoubleArray();
        Rectangle rect = new Rectangle(coords[0], coords[1], coords[2], coords[3]);

        List<BoundingBox> accepted = keptByClass.computeIfAbsent(classId, id -> new ArrayList<>());
        boolean overlapsAccepted = accepted.stream().anyMatch(prev -> prev.getIoU(rect) > iouLimit);
        if (overlapsAccepted) {
            // Suppressed by an already accepted box of the same class.
            continue;
        }

        accepted.add(rect);
        names.add(classes.get(classId));
        probabilities.add(probability);
        keptBoxes.add(rect);
    }
    return new DetectedObjects(names, probabilities, keptBoxes);
}
Also used : NDList(ai.djl.ndarray.NDList) ArrayList(java.util.ArrayList) Rectangle(ai.djl.modality.cv.output.Rectangle) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) BoundingBox(ai.djl.modality.cv.output.BoundingBox) NDArray(ai.djl.ndarray.NDArray) NDList(ai.djl.ndarray.NDList) ArrayList(java.util.ArrayList) List(java.util.List) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap)

Example 8 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class FaceDetectionTranslator method processOutput.

/**
 * {@inheritDoc}
 *
 * <p>Decodes box regressions and landmark offsets against the recover boxes, drops
 * candidates whose score is not above {@code confThresh}, keeps at most the {@code topK}
 * highest-scoring candidates, and then greedily suppresses any candidate whose IoU with an
 * already accepted box exceeds {@code nmsThresh}. Every accepted detection is returned as a
 * {@link Landmark} carrying five key points scaled by {@code width} and {@code height}.
 *
 * @param ctx the translator context supplying the {@link NDManager}
 * @param list model outputs; index 0 is read as box regressions, index 1 as class scores
 *     and index 2 as landmark regressions
 * @return the detections that survive thresholding and suppression, all named "Face"
 */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
    NDManager manager = ctx.getNDManager();
    double scaleXY = variance[0];
    double scaleWH = variance[1];

    // Drop column 0 of the scores; row 0 of the stacked result is the class index,
    // row 1 the matching score.
    NDArray prob = list.get(1).get(":, 1:");
    prob = NDArrays.stack(new NDList(prob.argMax(1).toType(DataType.FLOAT32, false), prob.max(new int[] { 1 })));

    // Decode sizes and positions relative to the recover boxes, then shift each
    // position by half the size.
    NDArray boxRecover = boxRecover(manager, width, height, scales, steps);
    NDArray boundingBoxes = list.get(0);
    NDArray bbWH = boundingBoxes.get(":, 2:").mul(scaleWH).exp().mul(boxRecover.get(":, 2:"));
    NDArray bbXY = boundingBoxes.get(":, :2").mul(scaleXY).mul(boxRecover.get(":, 2:")).add(boxRecover.get(":, :2")).sub(bbWH.mul(0.5f));
    boundingBoxes = NDArrays.concat(new NDList(bbXY, bbWH), 1);
    NDArray landms = list.get(2);
    landms = decodeLandm(landms, boxRecover, scaleXY);

    // filter the result below the threshold
    NDArray cutOff = prob.get(1).gt(confThresh);
    boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
    landms = landms.transpose().booleanMask(cutOff, 1).transpose();
    prob = prob.booleanMask(cutOff, 1);

    // start categorical filtering
    // argSort() orders indices by ascending score (per the DJL NDArray API), so the best
    // candidates sit at the END of the array. Slicing the first topK entries would keep
    // the weakest detections whenever more than topK candidates pass the threshold, so
    // the full ordering is taken and only its last topK entries are visited.
    long[] order = prob.get(1).argSort().toLongArray();
    long firstKept = Math.max(0L, order.length - (long) topK);
    prob = prob.transpose();

    List<String> retNames = new ArrayList<>();
    List<Double> retProbs = new ArrayList<>();
    List<BoundingBox> retBB = new ArrayList<>();
    Map<Integer, List<BoundingBox>> recorder = new ConcurrentHashMap<>();
    // Walk from the highest score downwards so stronger detections suppress weaker ones.
    for (int i = order.length - 1; i >= firstKept; i--) {
        long currMaxLoc = order[i];
        float[] classProb = prob.get(currMaxLoc).toFloatArray();
        int classId = (int) classProb[0];
        double probability = classProb[1];
        double[] boxArr = boundingBoxes.get(currMaxLoc).toDoubleArray();
        double[] landmsArr = landms.get(currMaxLoc).toDoubleArray();
        Rectangle rect = new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);

        // Boxes already accepted for this class; a new list is registered on first use.
        List<BoundingBox> boxes = recorder.computeIfAbsent(classId, id -> new ArrayList<>());
        boolean belowIoU = true;
        for (BoundingBox box : boxes) {
            if (box.getIoU(rect) > nmsThresh) {
                belowIoU = false;
                break;
            }
        }
        if (belowIoU) {
            List<Point> keyPoints = new ArrayList<>();
            for (int j = 0; j < 5; j++) {
                // 5 face landmarks, stored as consecutive (x, y) pairs
                double x = landmsArr[j * 2];
                double y = landmsArr[j * 2 + 1];
                keyPoints.add(new Point(x * width, y * height));
            }
            Landmark landmark = new Landmark(boxArr[0], boxArr[1], boxArr[2], boxArr[3], keyPoints);
            boxes.add(landmark);
            // The class name is fixed rather than looked up via classes.get(classId).
            String className = "Face";
            retNames.add(className);
            retProbs.add(probability);
            retBB.add(landmark);
        }
    }
    return new DetectedObjects(retNames, retProbs, retBB);
}
Also used : NDList(ai.djl.ndarray.NDList) ArrayList(java.util.ArrayList) Rectangle(ai.djl.modality.cv.output.Rectangle) BoundingBox(ai.djl.modality.cv.output.BoundingBox) NDList(ai.djl.ndarray.NDList) ArrayList(java.util.ArrayList) List(java.util.List) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) Landmark(ai.djl.modality.cv.output.Landmark) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) Point(ai.djl.modality.cv.output.Point) Point(ai.djl.modality.cv.output.Point) NDArray(ai.djl.ndarray.NDArray) NDManager(ai.djl.ndarray.NDManager)

Example 9 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class MaskDetectionTest method getSubImage.

/**
 * Crops a square region around the given bounding box.
 *
 * <p>The box coordinates are multiplied by the image dimensions, passed to
 * {@code extendSquare} with a factor of 0.18, and the third element of the result is used
 * as both the crop width and the crop height.
 *
 * @param img the image to crop from
 * @param box the bounding box whose bounds are scaled by the image size
 * @return the cropped sub-image
 */
private static Image getSubImage(Image img, BoundingBox box) {
    Rectangle bounds = box.getBounds();
    int imageWidth = img.getWidth();
    int imageHeight = img.getHeight();

    double left = bounds.getX() * imageWidth;
    double top = bounds.getY() * imageHeight;
    double boxWidth = bounds.getWidth() * imageWidth;
    double boxHeight = bounds.getHeight() * imageHeight;

    int[] square = extendSquare(left, top, boxWidth, boxHeight, 0.18);
    int side = square[2];
    return img.getSubImage(square[0], square[1], side, side);
}
Also used : Rectangle(ai.djl.modality.cv.output.Rectangle)

Example 10 with Rectangle

use of ai.djl.modality.cv.output.Rectangle in project djl by deepjavalibrary.

the class OCRTest method getSubImage.

/**
 * Crops the region described by the given bounding box after enlarging it.
 *
 * <p>The box bounds are first passed through {@code extendRect}; the four resulting values
 * are then multiplied by the image width or height and truncated to integers.
 *
 * @param img the image to crop from
 * @param box the bounding box whose bounds are extended and scaled by the image size
 * @return the cropped sub-image
 */
private static Image getSubImage(Image img, BoundingBox box) {
    Rectangle bounds = box.getBounds();
    double[] grown = extendRect(bounds.getX(), bounds.getY(), bounds.getWidth(), bounds.getHeight());

    int imageWidth = img.getWidth();
    int imageHeight = img.getHeight();

    // Scale to the image size; the casts truncate towards zero.
    int x = (int) (grown[0] * imageWidth);
    int y = (int) (grown[1] * imageHeight);
    int w = (int) (grown[2] * imageWidth);
    int h = (int) (grown[3] * imageHeight);
    return img.getSubImage(x, y, w, h);
}
Also used : Rectangle(ai.djl.modality.cv.output.Rectangle)

Aggregations

Rectangle (ai.djl.modality.cv.output.Rectangle)21 DetectedObjects (ai.djl.modality.cv.output.DetectedObjects)12 BoundingBox (ai.djl.modality.cv.output.BoundingBox)11 ArrayList (java.util.ArrayList)11 NDArray (ai.djl.ndarray.NDArray)7 Point (ai.djl.modality.cv.output.Point)5 List (java.util.List)5 Image (ai.djl.modality.cv.Image)4 NDList (ai.djl.ndarray.NDList)4 Path (java.nio.file.Path)4 FaceDetectedObjects (ai.djl.examples.detection.domain.FaceDetectedObjects)3 Artifact (ai.djl.repository.Artifact)3 PairList (ai.djl.util.PairList)3 Landmark (ai.djl.modality.cv.output.Landmark)2 ProgressBar (ai.djl.training.util.ProgressBar)2 SuppressLint (android.annotation.SuppressLint)2 Canvas (android.graphics.Canvas)2 Paint (android.graphics.Paint)2 Reader (java.io.Reader)2 Type (java.lang.reflect.Type)2